diff --git a/Cargo.lock b/Cargo.lock index 9eb8a02..d90db49 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,18 +24,18 @@ dependencies = [ [[package]] name = "ansi_term" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" dependencies = [ "winapi", ] [[package]] name = "anyhow" -version = "1.0.47" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d9ff5d688f1c13395289f67db01d4826b46dd694e7580accdc3e8430f2d98e" +checksum = "8b26702f315f53b6071259e15dd9d64528213b44d61de1ec926eca7715d62203" [[package]] name = "atty" @@ -60,43 +60,12 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" -[[package]] -name = "bindgen" -version = "0.59.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453c49e5950bb0eb63bb3df640e31618846c89d5b7faa54040d76e98e0134375" -dependencies = [ - "bitflags", - "cexpr", - "clang-sys", - "lazy_static", - "lazycell", - "peeking_take_while", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", -] - [[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" -[[package]] -name = "bitvec" -version = "0.19.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55f93d0ef3363c364d5976646a38f04cf67cfe1d4c8d160cdea02cab2c116b33" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] - [[package]] name = "byteorder" version = "1.4.3" @@ -115,15 +84,6 @@ version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22a9137b95ea06864e018375b72adfb7db6e6f68cfc8df5a04d00288050485ee" -[[package]] -name = "cexpr" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db507a7679252d2276ed0dd8113c6875ec56d3089f9225b2b42c30cc1f8e5c89" -dependencies = [ - "nom", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -136,21 +96,11 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e" -[[package]] -name = "clang-sys" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa66045b9cb23c2e9c1520732030608b02ee07e5cfaa5a521ec15ded7fa24c90" -dependencies = [ - "glob", - "libc", -] - [[package]] name = "clap" -version = "2.33.3" +version = "2.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e58ac78573c40708d45522f0d80fa2f01cc4f9b4e2bf749807255454312002" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" dependencies = [ "ansi_term", "atty", @@ -288,40 +238,33 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "funty" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed34cd105917e91daa4da6b3728c47b068749d6a62c59811f06ed2ac71d9da7" - [[package]] name = "futures-channel" -version = "0.3.17" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" +checksum = "ba3dda0b6588335f360afc675d0564c17a77a2bda81ca178a4b6081bd86c7f0b" dependencies = [ "futures-core", ] [[package]] name = "futures-core" -version = "0.3.17" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" +checksum = "d0c8ff0461b82559810cdccfde3215c3f373807f5e5232b71479bff7bb2583d7" [[package]] name = "futures-task" -version = "0.3.17" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99" +checksum = "6ee7c6485c30167ce4dfb83ac568a849fe53274c831081476ee13e0dce1aad72" [[package]] name = "futures-util" -version = "0.3.17" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481" +checksum = "d9b5cf40b47a271f77a8b1bec03ca09044d99d2372c0de244e66430761127164" dependencies = [ - "autocfg", "futures-core", "futures-task", "pin-project-lite", @@ -349,12 +292,6 @@ dependencies = [ "wasi", ] -[[package]] -name = "glob" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" - [[package]] name = "hashbrown" version = "0.11.2" @@ -409,7 +346,7 @@ checksum = "1323096b05d41827dadeaee54c9981958c0f94e670bc94ed80037d1a7b8b186b" dependencies = [ "bytes", "fnv", - "itoa", + "itoa 0.4.8", ] [[package]] @@ -446,9 +383,9 @@ dependencies = [ [[package]] name = "hyper" -version = "0.14.15" +version = "0.14.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436ec0091e4f20e655156a30a0df3770fe2900aa301e548e08446ec794b6953c" +checksum = "b7ec3e62bdc98a2f0393a5048e4c30ef659440ea6e0e572965103e72bd836f55" dependencies = [ "bytes", "futures-channel", @@ -458,7 +395,7 @@ dependencies = [ "http-body", "httparse", "httpdate", - "itoa", + "itoa 0.4.8", "pin-project-lite", "socket2", "tokio", @@ -510,29 +447,29 @@ version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" +[[package]] +name = "itoa" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" + [[package]] name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "libc" -version = "0.2.108" +version = "0.2.112" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8521a1b57e76b1ec69af7599e75e38e7b7fad6610f037db8c79b127201b5d119" +checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125" [[package]] name = "libsqlite3-sys" -version = "0.23.1" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abd5850c449b40bacb498b2bbdfaff648b1b055630073ba8db499caf2d0ea9f2" +checksum = "d2cafc7c74096c336d9d27145f7ebd4f4b6f95ba16aa5a282387267e6925cb58" dependencies = [ "pkg-config", "vcpkg", @@ -570,9 +507,9 @@ checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" [[package]] name = "memoffset" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9" +checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" dependencies = [ "autocfg", ] @@ -601,9 +538,9 @@ dependencies = [ [[package]] name = "netlink-packet-core" -version = "0.2.4" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac48279d5062bdf175bdbcb6b58ff1d6b0ecd54b951f7a0ff4bc0550fe903ccb" +checksum = "8349128e95f5dabcb8a18587ad06b3ca7993e90c0c360b4a2abac0313ebce727" dependencies = [ "anyhow", "byteorder", @@ -612,10 +549,23 @@ dependencies = [ ] [[package]] -name = "netlink-packet-route" -version = "0.8.0" +name = "netlink-packet-generic" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76aed5d3b6e3929713bf1e1334a11fd65180b6d9f5d7c8572664c48b122604f8" +checksum = "8678ffbbfef3dd88acbe85ed31d32f0de0a100854ee7d47fe5b250f81857a23b" +dependencies = [ + "anyhow", + "byteorder", + "libc", + "netlink-packet-core", + "netlink-packet-utils", +] + +[[package]] +name = "netlink-packet-route" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb5d54077de7c0904111e1d19b661b8cfccbc23d9ce5b6dbcc7362721e6e552" dependencies = [ "anyhow", "bitflags", @@ -627,9 +577,9 @@ dependencies = [ [[package]] name = "netlink-packet-utils" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fcfb6f758b66e964b2339596d94078218d96aad5b32003e8e2a1d23c27a6784" +checksum = "0a008a56eceb0cab06739c7f37f15bda27f1147a14d0e7136e8c913b94f1441d" dependencies = [ "anyhow", "byteorder", @@ -638,11 +588,36 @@ dependencies = [ ] [[package]] -name = "netlink-sys" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f48ea34ea0678719815c3753155067212f853ad2d8ef4a49167bae7f7c254188" +name = "netlink-packet-wireguard" +version = "0.1.1" +source = "git+https://github.com/mcginty/netlink?branch=wireguard-fixes#2b60e310ede5fa4c80c00874c19ee755b1bc8249" dependencies = [ + "anyhow", + "byteorder", + "libc", + "log", + "netlink-packet-generic", + "netlink-packet-utils", +] + +[[package]] +name = "netlink-request" +version = "0.1.0" +dependencies = [ + "netlink-packet-core", + "netlink-packet-generic", + "netlink-packet-route", + "netlink-packet-wireguard", + "netlink-sys", +] + +[[package]] +name = "netlink-sys" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed51a4602bb956eefef0ebc15f478bf9732fa3cc706e0a37112e654f41c5b92c" +dependencies = [ + "bytes", "libc", "log", ] @@ -660,18 +635,6 @@ dependencies = [ "memoffset", ] -[[package]] -name = "nom" -version = "6.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7413f999671bd4745a7b624bd370a569fb6bc574b23c83a3c5ed2e453f3d5e2" -dependencies = [ - "bitvec", - "funty", - "memchr", - "version_check", -] - [[package]] name = "ntapi" version = "0.3.6" @@ -728,12 +691,6 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0744126afe1a6dd7f394cb50a716dbe086cb06e255e53d8d0185d82828358fb5" -[[package]] -name = "peeking_take_while" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" - [[package]] name = "percent-encoding" version = "2.1.0" @@ -754,9 +711,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.22" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12295df4f294471248581bc09bef3c38a5e46f1e36d6a37353621a0c6c357e1f" +checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe" [[package]] name = "ppv-lite86" @@ -800,9 +757,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba508cc11742c0dc5c1659771673afbab7a0efab23aa17e854cbab0837ed0b43" +checksum = "fb37d2df5df740e582f28f8560cf425f52bb267d872fe58358eadb554909f07a" dependencies = [ "unicode-xid", ] @@ -826,12 +783,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "radium" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "941ba9d78d8e2f7ce474c015eea4d9c6d25b6a3327f9832ee29a4de27f91bbb8" - [[package]] name = "rand" version = "0.8.4" @@ -909,9 +860,9 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.26.1" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a82b0b91fad72160c56bf8da7a549b25d7c31109f52cc1437eac4c0ad2550a7" +checksum = "4ba4d3462c8b2e4d7f4fcfcf2b296dc6b65404fbbc7b63daa37fd485c149daf7" dependencies = [ "bitflags", "fallible-iterator", @@ -922,17 +873,11 @@ dependencies = [ "smallvec", ] -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "ryu" -version = "1.0.5" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" +checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" [[package]] name = "scopeguard" @@ -942,18 +887,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.130" +version = "1.0.131" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913" +checksum = "b4ad69dfbd3e45369132cc64e6748c2d65cdfb001a2b1c232d128b4ad60561c1" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.130" +version = "1.0.131" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7bc1a1ab1961464eae040d96713baa5a724a8152c1222492465b54322ec508b" +checksum = "b710a83c4e0dff6a3d511946b95274ad9ca9e5d3ae497b63fda866ac955358d2" dependencies = [ "proc-macro2", "quote", @@ -962,11 +907,11 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.71" +version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "063bf466a64011ac24040a49009724ee60a57da1b437617ceb32e53ad61bfb19" +checksum = "bcbd0344bc6533bc7ec56df11d42fb70f1b912351c0825ccb7211b59d8af7cf5" dependencies = [ - "itoa", + "itoa 1.0.1", "ryu", "serde", ] @@ -1019,6 +964,7 @@ dependencies = [ "log", "netlink-packet-core", "netlink-packet-route", + "netlink-request", "netlink-sys", "nix", "publicip", @@ -1028,15 +974,8 @@ dependencies = [ "toml", "url", "wireguard-control", - "wireguard-control-sys", ] -[[package]] -name = "shlex" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" - [[package]] name = "smallvec" version = "1.7.0" @@ -1091,21 +1030,15 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.81" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2afee18b8beb5a596ecb4a2dce128c719b4ba399d34126b9e4396e3f9860966" +checksum = "8daf5dd0bb60cbd4137b1b587d2fc0ae729bc07cf01cd70b36a1ed5ade3b9d59" dependencies = [ "proc-macro2", "quote", "unicode-xid", ] -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - [[package]] name = "tempfile" version = "3.2.0" @@ -1395,25 +1328,15 @@ dependencies = [ "curve25519-dalek", "hex", "libc", + "netlink-packet-core", + "netlink-packet-generic", + "netlink-packet-route", + "netlink-packet-wireguard", + "netlink-request", + "netlink-sys", "rand_core", - "wireguard-control-sys", ] -[[package]] -name = "wireguard-control-sys" -version = "1.5.2" -dependencies = [ - "bindgen", - "cc", - "libc", -] - -[[package]] -name = "wyz" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85e60b0d1b5f99db2556934e21937020776a5d31520bf169e851ac44e6420214" - [[package]] name = "zeroize" version = "1.4.3" diff --git a/Cargo.toml b/Cargo.toml index 05e25f8..2078ca2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["server", "client", "hostsfile", "shared", "publicip"] +members = ["server", "client", "hostsfile", "shared", "publicip", "netlink-request"] [profile.release] codegen-units = 1 diff --git a/client/src/main.rs b/client/src/main.rs index 5a86b52..e1af1f9 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -522,7 +522,10 @@ fn fetch( ); } - log::info!("bringing up interface {}.", interface.as_str_lossy().yellow()); + log::info!( + "bringing up interface {}.", + interface.as_str_lossy().yellow() + ); let resolved_endpoint = config .server .external_endpoint @@ -543,7 +546,10 @@ fn fetch( .with_str(interface.to_string())?; } - log::info!("fetching state for {} from server...", interface.as_str_lossy().yellow()); + log::info!( + "fetching state for {} from server...", + interface.as_str_lossy().yellow() + ); let mut store = DataStore::open_or_create(&opts.data_dir, interface)?; let api = Api::new(&config.server); let State { peers, cidrs } = api.http("GET", "/user/state")?; @@ -978,11 +984,22 @@ fn show(opts: &Opts, short: bool, tree: bool, interface: Option) -> R } for (device_info, store) in devices { + let public_key = match &device_info.public_key { + Some(key) => key.to_base64(), + None => { + log::warn!( + "network {} is missing public key.", + device_info.name.to_string().yellow() + ); + continue; + }, + }; + let peers = store.peers(); let cidrs = store.cidrs(); let me = peers .iter() - .find(|p| p.public_key == device_info.public_key.as_ref().unwrap().to_base64()) + .find(|p| p.public_key == public_key) .ok_or_else(|| anyhow!("missing peer info"))?; let mut peer_states = device_info diff --git a/client/src/util.rs b/client/src/util.rs index 79772ad..58bb9a6 100644 --- a/client/src/util.rs +++ b/client/src/util.rs @@ -3,8 +3,8 @@ use colored::*; use indoc::eprintdoc; use log::{Level, LevelFilter}; use serde::{de::DeserializeOwned, Serialize}; -use shared::{interface_config::ServerInfo, PeerDiff, INNERNET_PUBKEY_HEADER, Interface}; -use std::{io, path::Path, time::Duration, ffi::OsStr}; +use shared::{interface_config::ServerInfo, Interface, PeerDiff, INNERNET_PUBKEY_HEADER}; +use std::{ffi::OsStr, io, path::Path, time::Duration}; use ureq::{Agent, AgentBuilder}; static LOGGER: Logger = Logger; @@ -176,18 +176,19 @@ pub fn all_installed(config_dir: &Path) -> Result, std::io::Error .into_iter() .collect::>()?; - let installed: Vec<_> = entries.into_iter() + let installed: Vec<_> = entries + .into_iter() .filter(|entry| match entry.file_type() { Ok(f) => f.is_file(), - _ => false + _ => false, }) .filter_map(|entry| { let path = entry.path(); match (path.extension(), path.file_stem()) { (Some(extension), Some(stem)) if extension == OsStr::new("conf") => { Some(stem.to_string_lossy().to_string()) - } - _ => None + }, + _ => None, } }) .map(|name| name.parse()) diff --git a/netlink-request/Cargo.toml b/netlink-request/Cargo.toml new file mode 100644 index 0000000..36e1cae --- /dev/null +++ b/netlink-request/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "netlink-request" +version = "0.1.0" +edition = "2021" + +[target.'cfg(target_os = "linux")'.dependencies] +netlink-sys = "0.8" +netlink-packet-core = "0.4" +netlink-packet-generic = "0.3" +netlink-packet-route = "0.10" +netlink-packet-wireguard = { git = "https://github.com/mcginty/netlink", branch = "wireguard-fixes" } diff --git a/netlink-request/src/lib.rs b/netlink-request/src/lib.rs new file mode 100644 index 0000000..3259d57 --- /dev/null +++ b/netlink-request/src/lib.rs @@ -0,0 +1,125 @@ +#[cfg(target_os = "linux")] +mod linux { + use netlink_packet_core::{ + NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK, + NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, + }; + use netlink_packet_generic::{ + ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + GenlFamily, GenlMessage, + }; + use netlink_packet_route::RtnlMessage; + use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket}; + use std::{fmt::Debug, io}; + + macro_rules! get_nla_value { + ($nlas:expr, $e:ident, $v:ident) => { + $nlas.iter().find_map(|attr| match attr { + $e::$v(value) => Some(value), + _ => None, + }) + }; + } + + pub fn netlink_request_genl( + mut message: GenlMessage, + flags: Option, + ) -> Result>>, io::Error> + where + F: GenlFamily + Clone + Debug + Eq, + GenlMessage: Clone + Debug + Eq + NetlinkSerializable + NetlinkDeserializable, + { + if message.family_id() == 0 { + let genlmsg: GenlMessage = GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyName(F::family_name().to_string())], + }); + let responses = + netlink_request_genl::(genlmsg, Some(NLM_F_REQUEST | NLM_F_ACK))?; + + match responses.get(0) { + Some(NetlinkMessage { + payload: + NetlinkPayload::InnerMessage(GenlMessage { + payload: GenlCtrl { nlas, .. }, + .. + }), + .. + }) => { + let family_id = get_nla_value!(nlas, GenlCtrlAttrs, FamilyId) + .ok_or_else(|| io::ErrorKind::NotFound)?; + message.set_resolved_family_id(*family_id); + }, + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Unexpected netlink payload", + )) + }, + }; + } + netlink_request(message, flags, NETLINK_GENERIC) + } + + pub fn netlink_request_rtnl( + message: RtnlMessage, + flags: Option, + ) -> Result>, io::Error> { + netlink_request(message, flags, NETLINK_ROUTE) + } + + pub fn netlink_request( + message: I, + flags: Option, + socket: isize, + ) -> Result>, io::Error> + where + NetlinkPayload: From, + I: Clone + Debug + Eq + NetlinkSerializable + NetlinkDeserializable, + { + let mut req = NetlinkMessage::from(message); + req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); + req.finalize(); + let mut buf = [0; 4096]; + req.serialize(&mut buf); + let len = req.buffer_len(); + + let socket = Socket::new(socket)?; + let kernel_addr = netlink_sys::SocketAddr::new(0, 0); + socket.connect(&kernel_addr)?; + let n_sent = socket.send(&buf[..len], 0)?; + if n_sent != len { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "failed to send netlink request", + )); + } + + let mut responses = vec![]; + loop { + let n_received = socket.recv(&mut &mut buf[..], 0)?; + let mut offset = 0; + loop { + let bytes = &buf[offset..]; + let response = NetlinkMessage::::deserialize(bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + responses.push(response.clone()); + match response.payload { + // We've parsed all parts of the response and can leave the loop. + NetlinkPayload::Ack(_) | NetlinkPayload::Done => return Ok(responses), + NetlinkPayload::Error(e) => return Err(e.into()), + _ => {}, + } + offset += response.header.length as usize; + if offset == n_received || response.header.length == 0 { + // We've fully parsed the datagram, but there may be further datagrams + // with additional netlink response parts. + break; + } + } + } + } +} + +#[cfg(target_os = "linux")] +pub use linux::{netlink_request, netlink_request_genl, netlink_request_rtnl}; diff --git a/shared/Cargo.toml b/shared/Cargo.toml index e7b902b..f13058c 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -25,10 +25,10 @@ url = "2" wireguard-control = { path = "../wireguard-control" } [target.'cfg(target_os = "linux")'.dependencies] -netlink-sys = "0.7" -netlink-packet-core = "0.2" -netlink-packet-route = "0.8" -wireguard-control-sys = { path = "../wireguard-control-sys" } +netlink-sys = "0.8" +netlink-packet-core = "0.4" +netlink-packet-route = "0.10" +netlink-request = { path = "../netlink-request" } [target.'cfg(target_os = "macos")'.dependencies] nix = "0.23" diff --git a/shared/src/netlink.rs b/shared/src/netlink.rs index 4e329dc..bb97880 100644 --- a/shared/src/netlink.rs +++ b/shared/src/netlink.rs @@ -1,7 +1,5 @@ use ipnetwork::IpNetwork; -use netlink_packet_core::{ - NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, -}; +use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_REQUEST}; use netlink_packet_route::{ address, constants::*, @@ -9,7 +7,7 @@ use netlink_packet_route::{ route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, RouteHeader, RouteMessage, RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN, }; -use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr}; +use netlink_request::netlink_request_rtnl; use std::{io, net::IpAddr}; use wireguard_control::InterfaceName; @@ -23,55 +21,6 @@ fn if_nametoindex(interface: &InterfaceName) -> Result { } } -fn netlink_call( - message: RtnlMessage, - flags: Option, -) -> Result>, io::Error> { - let mut req = NetlinkMessage::from(message); - req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); - req.finalize(); - let mut buf = [0; 4096]; - req.serialize(&mut buf); - let len = req.buffer_len(); - - log::trace!("netlink request: {:?}", req); - let socket = Socket::new(NETLINK_ROUTE)?; - let kernel_addr = SocketAddr::new(0, 0); - socket.connect(&kernel_addr)?; - let n_sent = socket.send(&buf[..len], 0)?; - if n_sent != len { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "failed to send netlink request", - )); - } - - let mut responses = vec![]; - loop { - let n_received = socket.recv(&mut buf[..], 0)?; - let mut offset = 0; - loop { - let bytes = &buf[offset..]; - let response = NetlinkMessage::::deserialize(bytes) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - responses.push(response.clone()); - log::trace!("netlink response: {:?}", response); - match response.payload { - // We've parsed all parts of the response and can leave the loop. - NetlinkPayload::Ack(_) | NetlinkPayload::Done => return Ok(responses), - NetlinkPayload::Error(e) => return Err(e.into()), - _ => {}, - } - offset += response.header.length as usize; - if offset == n_received || response.header.length == 0 { - // We've fully parsed the datagram, but there may be further datagrams - // with additional netlink response parts. - break; - } - } - } -} - pub fn set_up(interface: &InterfaceName, mtu: u32) -> Result<(), io::Error> { let index = if_nametoindex(interface)?; let message = LinkMessage { @@ -82,7 +31,8 @@ pub fn set_up(interface: &InterfaceName, mtu: u32) -> Result<(), io::Error> { }, nlas: vec![link::nlas::Nla::Mtu(mtu)], }; - netlink_call(RtnlMessage::SetLink(message), None)?; + netlink_request_rtnl(RtnlMessage::SetLink(message), None)?; + log::debug!("set interface {} up with mtu {}", interface, mtu); Ok(()) } @@ -114,10 +64,11 @@ pub fn set_addr(interface: &InterfaceName, addr: IpNetwork) -> Result<(), io::Er }, nlas, }; - netlink_call( + netlink_request_rtnl( RtnlMessage::NewAddress(message), Some(NLM_F_REQUEST | NLM_F_ACK | NLM_F_REPLACE | NLM_F_CREATE), )?; + log::debug!("set address {} on interface {}", addr, interface); Ok(()) } @@ -140,15 +91,21 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result Ok(true), - Err(e) if e.kind() == io::ErrorKind::AlreadyExists => Ok(false), + match netlink_request_rtnl(RtnlMessage::NewRoute(message), None) { + Ok(_) => { + log::debug!("added route {} to interface {}", cidr, interface); + Ok(true) + }, + Err(e) if e.kind() == io::ErrorKind::AlreadyExists => { + log::debug!("route {} already existed.", cidr); + Ok(false) + }, Err(e) => Err(e), } } fn get_links() -> Result, io::Error> { - let link_responses = netlink_call( + let link_responses = netlink_request_rtnl( RtnlMessage::GetLink(LinkMessage::default()), Some(NLM_F_DUMP | NLM_F_REQUEST), )?; @@ -181,7 +138,7 @@ fn get_links() -> Result, io::Error> { pub fn get_local_addrs() -> Result, io::Error> { let links = get_links()?; - let addr_responses = netlink_call( + let addr_responses = netlink_request_rtnl( RtnlMessage::GetAddress(AddressMessage::default()), Some(NLM_F_DUMP | NLM_F_REQUEST), )?; diff --git a/wireguard-control-sys/.gitignore b/wireguard-control-sys/.gitignore deleted file mode 100644 index 6936990..0000000 --- a/wireguard-control-sys/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -/target -**/*.rs.bk -Cargo.lock diff --git a/wireguard-control-sys/Cargo.toml b/wireguard-control-sys/Cargo.toml deleted file mode 100644 index 0606b26..0000000 --- a/wireguard-control-sys/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -authors = ["K900 ", "Jake McGinty "] -categories = ["external-ffi-bindings", "os::unix-apis"] -description = "Raw bindings to the WireGuard embeddable C library" -license = "LGPL-2.1-or-later" -name = "wireguard-control-sys" -readme = "README.md" -repository = "https://github.com/tonarino/innernet" -version = "1.5.2" - -[dependencies] -libc = "0.2" - -[features] -buildtime_bindgen = ["bindgen"] - -[build-dependencies] -bindgen = { version = "0", default-features = false, optional = true } -cc = "1.0" diff --git a/wireguard-control-sys/README.md b/wireguard-control-sys/README.md deleted file mode 100644 index bd0017d..0000000 --- a/wireguard-control-sys/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# `wireguard-control-sys` - -A low-level FFI around the [`embaddable-wg-library`](https://git.zx2c4.com/wireguard-tools/tree/contrib/embeddable-wg-library) WireGuard C library, which in turn communicates with the Linux kernel WireGuard via Netlink. - -You *probably* want to use the [`wireguard-control`](https://crates.io/crates/wireguard-control) crate instead. diff --git a/wireguard-control-sys/bindgen-bindings/bindings.rs b/wireguard-control-sys/bindgen-bindings/bindings.rs deleted file mode 100644 index a45a3fd..0000000 --- a/wireguard-control-sys/bindgen-bindings/bindings.rs +++ /dev/null @@ -1,968 +0,0 @@ -/* automatically generated by rust-bindgen 0.59.1 */ - -pub type __uint8_t = ::std::os::raw::c_uchar; -pub type __uint16_t = ::std::os::raw::c_ushort; -pub type __uint32_t = ::std::os::raw::c_uint; -pub type __int64_t = ::std::os::raw::c_long; -pub type __uint64_t = ::std::os::raw::c_ulong; -pub type sa_family_t = ::std::os::raw::c_ushort; -#[repr(C)] -#[derive(Debug, Default, Copy, Clone)] -pub struct sockaddr { - pub sa_family: sa_family_t, - pub sa_data: [::std::os::raw::c_char; 14usize], -} -#[test] -fn bindgen_test_layout_sockaddr() { - assert_eq!( - ::std::mem::size_of::(), - 16usize, - concat!("Size of: ", stringify!(sockaddr)) - ); - assert_eq!( - ::std::mem::align_of::(), - 2usize, - concat!("Alignment of ", stringify!(sockaddr)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sa_family as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(sockaddr), - "::", - stringify!(sa_family) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sa_data as *const _ as usize }, - 2usize, - concat!( - "Offset of field: ", - stringify!(sockaddr), - "::", - stringify!(sa_data) - ) - ); -} -pub type in_addr_t = u32; -#[repr(C)] -#[derive(Debug, Default, Copy, Clone)] -pub struct in_addr { - pub s_addr: in_addr_t, -} -#[test] -fn bindgen_test_layout_in_addr() { - assert_eq!( - ::std::mem::size_of::(), - 4usize, - concat!("Size of: ", stringify!(in_addr)) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!("Alignment of ", stringify!(in_addr)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).s_addr as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(in_addr), - "::", - stringify!(s_addr) - ) - ); -} -pub type in_port_t = u16; -#[repr(C)] -#[derive(Copy, Clone)] -pub struct in6_addr { - pub __in6_u: in6_addr__bindgen_ty_1, -} -#[repr(C)] -#[derive(Copy, Clone)] -pub union in6_addr__bindgen_ty_1 { - pub __u6_addr8: [u8; 16usize], - pub __u6_addr16: [u16; 8usize], - pub __u6_addr32: [u32; 4usize], -} -#[test] -fn bindgen_test_layout_in6_addr__bindgen_ty_1() { - assert_eq!( - ::std::mem::size_of::(), - 16usize, - concat!("Size of: ", stringify!(in6_addr__bindgen_ty_1)) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!("Alignment of ", stringify!(in6_addr__bindgen_ty_1)) - ); - assert_eq!( - unsafe { - &(*(::std::ptr::null::())).__u6_addr8 as *const _ as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(in6_addr__bindgen_ty_1), - "::", - stringify!(__u6_addr8) - ) - ); - assert_eq!( - unsafe { - &(*(::std::ptr::null::())).__u6_addr16 as *const _ as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(in6_addr__bindgen_ty_1), - "::", - stringify!(__u6_addr16) - ) - ); - assert_eq!( - unsafe { - &(*(::std::ptr::null::())).__u6_addr32 as *const _ as usize - }, - 0usize, - concat!( - "Offset of field: ", - stringify!(in6_addr__bindgen_ty_1), - "::", - stringify!(__u6_addr32) - ) - ); -} -impl Default for in6_addr__bindgen_ty_1 { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -impl ::std::fmt::Debug for in6_addr__bindgen_ty_1 { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "in6_addr__bindgen_ty_1 {{ union }}") - } -} -#[test] -fn bindgen_test_layout_in6_addr() { - assert_eq!( - ::std::mem::size_of::(), - 16usize, - concat!("Size of: ", stringify!(in6_addr)) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!("Alignment of ", stringify!(in6_addr)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).__in6_u as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(in6_addr), - "::", - stringify!(__in6_u) - ) - ); -} -impl Default for in6_addr { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -impl ::std::fmt::Debug for in6_addr { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "in6_addr {{ __in6_u: {:?} }}", self.__in6_u) - } -} -#[repr(C)] -#[derive(Debug, Default, Copy, Clone)] -pub struct sockaddr_in { - pub sin_family: sa_family_t, - pub sin_port: in_port_t, - pub sin_addr: in_addr, - pub sin_zero: [::std::os::raw::c_uchar; 8usize], -} -#[test] -fn bindgen_test_layout_sockaddr_in() { - assert_eq!( - ::std::mem::size_of::(), - 16usize, - concat!("Size of: ", stringify!(sockaddr_in)) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!("Alignment of ", stringify!(sockaddr_in)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin_family as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in), - "::", - stringify!(sin_family) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin_port as *const _ as usize }, - 2usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in), - "::", - stringify!(sin_port) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin_addr as *const _ as usize }, - 4usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in), - "::", - stringify!(sin_addr) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin_zero as *const _ as usize }, - 8usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in), - "::", - stringify!(sin_zero) - ) - ); -} -#[repr(C)] -#[derive(Copy, Clone)] -pub struct sockaddr_in6 { - pub sin6_family: sa_family_t, - pub sin6_port: in_port_t, - pub sin6_flowinfo: u32, - pub sin6_addr: in6_addr, - pub sin6_scope_id: u32, -} -#[test] -fn bindgen_test_layout_sockaddr_in6() { - assert_eq!( - ::std::mem::size_of::(), - 28usize, - concat!("Size of: ", stringify!(sockaddr_in6)) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!("Alignment of ", stringify!(sockaddr_in6)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin6_family as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in6), - "::", - stringify!(sin6_family) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin6_port as *const _ as usize }, - 2usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in6), - "::", - stringify!(sin6_port) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin6_flowinfo as *const _ as usize }, - 4usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in6), - "::", - stringify!(sin6_flowinfo) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin6_addr as *const _ as usize }, - 8usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in6), - "::", - stringify!(sin6_addr) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).sin6_scope_id as *const _ as usize }, - 24usize, - concat!( - "Offset of field: ", - stringify!(sockaddr_in6), - "::", - stringify!(sin6_scope_id) - ) - ); -} -impl Default for sockaddr_in6 { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -impl ::std::fmt::Debug for sockaddr_in6 { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write ! (f , "sockaddr_in6 {{ sin6_family: {:?}, sin6_port: {:?}, sin6_flowinfo: {:?}, sin6_addr: {:?}, sin6_scope_id: {:?} }}" , self . sin6_family , self . sin6_port , self . sin6_flowinfo , self . sin6_addr , self . sin6_scope_id) - } -} -pub type wg_key = [u8; 32usize]; -pub type wg_key_b64_string = [::std::os::raw::c_char; 45usize]; -#[repr(C)] -#[derive(Debug, Default, Copy, Clone)] -pub struct timespec64 { - pub tv_sec: i64, - pub tv_nsec: i64, -} -#[test] -fn bindgen_test_layout_timespec64() { - assert_eq!( - ::std::mem::size_of::(), - 16usize, - concat!("Size of: ", stringify!(timespec64)) - ); - assert_eq!( - ::std::mem::align_of::(), - 8usize, - concat!("Alignment of ", stringify!(timespec64)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).tv_sec as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(timespec64), - "::", - stringify!(tv_sec) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).tv_nsec as *const _ as usize }, - 8usize, - concat!( - "Offset of field: ", - stringify!(timespec64), - "::", - stringify!(tv_nsec) - ) - ); -} -#[repr(C)] -#[derive(Copy, Clone)] -pub struct wg_allowedip { - pub family: u16, - pub __bindgen_anon_1: wg_allowedip__bindgen_ty_1, - pub cidr: u8, - pub next_allowedip: *mut wg_allowedip, -} -#[repr(C)] -#[derive(Copy, Clone)] -pub union wg_allowedip__bindgen_ty_1 { - pub ip4: in_addr, - pub ip6: in6_addr, -} -#[test] -fn bindgen_test_layout_wg_allowedip__bindgen_ty_1() { - assert_eq!( - ::std::mem::size_of::(), - 16usize, - concat!("Size of: ", stringify!(wg_allowedip__bindgen_ty_1)) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!("Alignment of ", stringify!(wg_allowedip__bindgen_ty_1)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).ip4 as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_allowedip__bindgen_ty_1), - "::", - stringify!(ip4) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).ip6 as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_allowedip__bindgen_ty_1), - "::", - stringify!(ip6) - ) - ); -} -impl Default for wg_allowedip__bindgen_ty_1 { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -impl ::std::fmt::Debug for wg_allowedip__bindgen_ty_1 { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "wg_allowedip__bindgen_ty_1 {{ union }}") - } -} -#[test] -fn bindgen_test_layout_wg_allowedip() { - assert_eq!( - ::std::mem::size_of::(), - 32usize, - concat!("Size of: ", stringify!(wg_allowedip)) - ); - assert_eq!( - ::std::mem::align_of::(), - 8usize, - concat!("Alignment of ", stringify!(wg_allowedip)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).family as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_allowedip), - "::", - stringify!(family) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).cidr as *const _ as usize }, - 20usize, - concat!( - "Offset of field: ", - stringify!(wg_allowedip), - "::", - stringify!(cidr) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).next_allowedip as *const _ as usize }, - 24usize, - concat!( - "Offset of field: ", - stringify!(wg_allowedip), - "::", - stringify!(next_allowedip) - ) - ); -} -impl Default for wg_allowedip { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -impl ::std::fmt::Debug for wg_allowedip { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write ! (f , "wg_allowedip {{ family: {:?}, __bindgen_anon_1: {:?}, cidr: {:?}, next_allowedip: {:?} }}" , self . family , self . __bindgen_anon_1 , self . cidr , self . next_allowedip) - } -} -impl wg_peer_flags { - pub const WGPEER_REMOVE_ME: wg_peer_flags = wg_peer_flags(1); -} -impl wg_peer_flags { - pub const WGPEER_REPLACE_ALLOWEDIPS: wg_peer_flags = wg_peer_flags(2); -} -impl wg_peer_flags { - pub const WGPEER_HAS_PUBLIC_KEY: wg_peer_flags = wg_peer_flags(4); -} -impl wg_peer_flags { - pub const WGPEER_HAS_PRESHARED_KEY: wg_peer_flags = wg_peer_flags(8); -} -impl wg_peer_flags { - pub const WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL: wg_peer_flags = wg_peer_flags(16); -} -impl ::std::ops::BitOr for wg_peer_flags { - type Output = Self; - #[inline] - fn bitor(self, other: Self) -> Self { - wg_peer_flags(self.0 | other.0) - } -} -impl ::std::ops::BitOrAssign for wg_peer_flags { - #[inline] - fn bitor_assign(&mut self, rhs: wg_peer_flags) { - self.0 |= rhs.0; - } -} -impl ::std::ops::BitAnd for wg_peer_flags { - type Output = Self; - #[inline] - fn bitand(self, other: Self) -> Self { - wg_peer_flags(self.0 & other.0) - } -} -impl ::std::ops::BitAndAssign for wg_peer_flags { - #[inline] - fn bitand_assign(&mut self, rhs: wg_peer_flags) { - self.0 &= rhs.0; - } -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct wg_peer_flags(pub ::std::os::raw::c_uint); -#[repr(C)] -#[derive(Copy, Clone)] -pub union wg_endpoint { - pub addr: sockaddr, - pub addr4: sockaddr_in, - pub addr6: sockaddr_in6, -} -#[test] -fn bindgen_test_layout_wg_endpoint() { - assert_eq!( - ::std::mem::size_of::(), - 28usize, - concat!("Size of: ", stringify!(wg_endpoint)) - ); - assert_eq!( - ::std::mem::align_of::(), - 4usize, - concat!("Alignment of ", stringify!(wg_endpoint)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).addr as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_endpoint), - "::", - stringify!(addr) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).addr4 as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_endpoint), - "::", - stringify!(addr4) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).addr6 as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_endpoint), - "::", - stringify!(addr6) - ) - ); -} -impl Default for wg_endpoint { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -impl ::std::fmt::Debug for wg_endpoint { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "wg_endpoint {{ union }}") - } -} -#[repr(C)] -#[derive(Copy, Clone)] -pub struct wg_peer { - pub flags: wg_peer_flags, - pub public_key: wg_key, - pub preshared_key: wg_key, - pub endpoint: wg_endpoint, - pub last_handshake_time: timespec64, - pub rx_bytes: u64, - pub tx_bytes: u64, - pub persistent_keepalive_interval: u16, - pub first_allowedip: *mut wg_allowedip, - pub last_allowedip: *mut wg_allowedip, - pub next_peer: *mut wg_peer, -} -#[test] -fn bindgen_test_layout_wg_peer() { - assert_eq!( - ::std::mem::size_of::(), - 160usize, - concat!("Size of: ", stringify!(wg_peer)) - ); - assert_eq!( - ::std::mem::align_of::(), - 8usize, - concat!("Alignment of ", stringify!(wg_peer)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).flags as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(flags) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).public_key as *const _ as usize }, - 4usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(public_key) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).preshared_key as *const _ as usize }, - 36usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(preshared_key) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).endpoint as *const _ as usize }, - 68usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(endpoint) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).last_handshake_time as *const _ as usize }, - 96usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(last_handshake_time) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).rx_bytes as *const _ as usize }, - 112usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(rx_bytes) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).tx_bytes as *const _ as usize }, - 120usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(tx_bytes) - ) - ); - assert_eq!( - unsafe { - &(*(::std::ptr::null::())).persistent_keepalive_interval as *const _ as usize - }, - 128usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(persistent_keepalive_interval) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).first_allowedip as *const _ as usize }, - 136usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(first_allowedip) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).last_allowedip as *const _ as usize }, - 144usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(last_allowedip) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).next_peer as *const _ as usize }, - 152usize, - concat!( - "Offset of field: ", - stringify!(wg_peer), - "::", - stringify!(next_peer) - ) - ); -} -impl Default for wg_peer { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -impl ::std::fmt::Debug for wg_peer { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write ! (f , "wg_peer {{ flags: {:?}, public_key: [{}], preshared_key: [{}], endpoint: {:?}, last_handshake_time: {:?}, rx_bytes: {:?}, tx_bytes: {:?}, persistent_keepalive_interval: {:?}, first_allowedip: {:?}, last_allowedip: {:?}, next_peer: {:?} }}" , self . flags , self . public_key . iter () . enumerate () . map (| (i , v) | format ! ("{}{:?}" , if i > 0 { ", " } else { "" } , v)) . collect :: < String > () , self . preshared_key . iter () . enumerate () . map (| (i , v) | format ! ("{}{:?}" , if i > 0 { ", " } else { "" } , v)) . collect :: < String > () , self . endpoint , self . last_handshake_time , self . rx_bytes , self . tx_bytes , self . persistent_keepalive_interval , self . first_allowedip , self . last_allowedip , self . next_peer) - } -} -impl wg_device_flags { - pub const WGDEVICE_REPLACE_PEERS: wg_device_flags = wg_device_flags(1); -} -impl wg_device_flags { - pub const WGDEVICE_HAS_PRIVATE_KEY: wg_device_flags = wg_device_flags(2); -} -impl wg_device_flags { - pub const WGDEVICE_HAS_PUBLIC_KEY: wg_device_flags = wg_device_flags(4); -} -impl wg_device_flags { - pub const WGDEVICE_HAS_LISTEN_PORT: wg_device_flags = wg_device_flags(8); -} -impl wg_device_flags { - pub const WGDEVICE_HAS_FWMARK: wg_device_flags = wg_device_flags(16); -} -impl ::std::ops::BitOr for wg_device_flags { - type Output = Self; - #[inline] - fn bitor(self, other: Self) -> Self { - wg_device_flags(self.0 | other.0) - } -} -impl ::std::ops::BitOrAssign for wg_device_flags { - #[inline] - fn bitor_assign(&mut self, rhs: wg_device_flags) { - self.0 |= rhs.0; - } -} -impl ::std::ops::BitAnd for wg_device_flags { - type Output = Self; - #[inline] - fn bitand(self, other: Self) -> Self { - wg_device_flags(self.0 & other.0) - } -} -impl ::std::ops::BitAndAssign for wg_device_flags { - #[inline] - fn bitand_assign(&mut self, rhs: wg_device_flags) { - self.0 &= rhs.0; - } -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct wg_device_flags(pub ::std::os::raw::c_uint); -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct wg_device { - pub name: [::std::os::raw::c_char; 16usize], - pub ifindex: u32, - pub flags: wg_device_flags, - pub public_key: wg_key, - pub private_key: wg_key, - pub fwmark: u32, - pub listen_port: u16, - pub first_peer: *mut wg_peer, - pub last_peer: *mut wg_peer, -} -#[test] -fn bindgen_test_layout_wg_device() { - assert_eq!( - ::std::mem::size_of::(), - 112usize, - concat!("Size of: ", stringify!(wg_device)) - ); - assert_eq!( - ::std::mem::align_of::(), - 8usize, - concat!("Alignment of ", stringify!(wg_device)) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).name as *const _ as usize }, - 0usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(name) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).ifindex as *const _ as usize }, - 16usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(ifindex) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).flags as *const _ as usize }, - 20usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(flags) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).public_key as *const _ as usize }, - 24usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(public_key) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).private_key as *const _ as usize }, - 56usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(private_key) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).fwmark as *const _ as usize }, - 88usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(fwmark) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).listen_port as *const _ as usize }, - 92usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(listen_port) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).first_peer as *const _ as usize }, - 96usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(first_peer) - ) - ); - assert_eq!( - unsafe { &(*(::std::ptr::null::())).last_peer as *const _ as usize }, - 104usize, - concat!( - "Offset of field: ", - stringify!(wg_device), - "::", - stringify!(last_peer) - ) - ); -} -impl Default for wg_device { - fn default() -> Self { - let mut s = ::std::mem::MaybeUninit::::uninit(); - unsafe { - ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1); - s.assume_init() - } - } -} -extern "C" { - pub fn wg_set_device(dev: *mut wg_device) -> ::std::os::raw::c_int; -} -extern "C" { - pub fn wg_get_device( - dev: *mut *mut wg_device, - device_name: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -extern "C" { - pub fn wg_add_device(device_name: *const ::std::os::raw::c_char) -> ::std::os::raw::c_int; -} -extern "C" { - pub fn wg_del_device(device_name: *const ::std::os::raw::c_char) -> ::std::os::raw::c_int; -} -extern "C" { - pub fn wg_free_device(dev: *mut wg_device); -} -extern "C" { - pub fn wg_list_device_names() -> *mut ::std::os::raw::c_char; -} -extern "C" { - pub fn wg_key_to_base64(base64: *mut ::std::os::raw::c_char, key: *mut u8); -} -extern "C" { - pub fn wg_key_from_base64( - key: *mut u8, - base64: *mut ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -extern "C" { - pub fn wg_key_is_zero(key: *mut u8) -> bool; -} -extern "C" { - pub fn wg_generate_public_key(public_key: *mut u8, private_key: *mut u8); -} -extern "C" { - pub fn wg_generate_private_key(private_key: *mut u8); -} -extern "C" { - pub fn wg_generate_preshared_key(preshared_key: *mut u8); -} diff --git a/wireguard-control-sys/build.rs b/wireguard-control-sys/build.rs deleted file mode 100644 index 23d1cd4..0000000 --- a/wireguard-control-sys/build.rs +++ /dev/null @@ -1,49 +0,0 @@ -#[cfg(target_os = "linux")] -mod linux { - use std::{env, path::PathBuf}; - - pub fn build_bindings() { - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - - #[cfg(feature = "buildtime_bindgen")] - { - let bindings = bindgen::Builder::default() - .rust_target(bindgen::RustTarget::Stable_1_40) - .derive_default(true) - .header("c/wireguard.h") - .impl_debug(true) - .allowlist_function("wg_.*") - .bitfield_enum("wg_peer_flags") - .bitfield_enum("wg_device_flags"); - - let bindings = bindings.generate().expect("Unable to generate bindings"); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); - } - #[cfg(not(feature = "buildtime_bindgen"))] - { - std::fs::copy("bindgen-bindings/bindings.rs", out_path.join("bindings.rs")) - .expect("Could not copy bindings to output directory"); - } - } - - pub fn build_library() { - cc::Build::new() - .file("c/wireguard.c") - .warnings(true) - .extra_warnings(true) - .warnings_into_errors(true) - .flag_if_supported("-Wno-unused-parameter") - .compile("wireguard"); - } -} - -#[cfg(target_os = "linux")] -fn main() { - linux::build_bindings(); - linux::build_library(); -} - -#[cfg(not(target_os = "linux"))] -fn main() {} diff --git a/wireguard-control-sys/c/wireguard.c b/wireguard-control-sys/c/wireguard.c deleted file mode 100644 index 4941549..0000000 --- a/wireguard-control-sys/c/wireguard.c +++ /dev/null @@ -1,1755 +0,0 @@ -// SPDX-License-Identifier: LGPL-2.1+ -/* - * Copyright (C) 2015-2020 Jason A. Donenfeld . All Rights Reserved. - * Copyright (C) 2008-2012 Pablo Neira Ayuso . - */ - -#define _GNU_SOURCE - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "wireguard.h" - -/* wireguard.h netlink uapi: */ - -#define WG_GENL_NAME "wireguard" -#define WG_GENL_VERSION 1 - -enum wg_cmd { - WG_CMD_GET_DEVICE, - WG_CMD_SET_DEVICE, - __WG_CMD_MAX -}; - -enum wgdevice_flag { - WGDEVICE_F_REPLACE_PEERS = 1U << 0 -}; -enum wgdevice_attribute { - WGDEVICE_A_UNSPEC, - WGDEVICE_A_IFINDEX, - WGDEVICE_A_IFNAME, - WGDEVICE_A_PRIVATE_KEY, - WGDEVICE_A_PUBLIC_KEY, - WGDEVICE_A_FLAGS, - WGDEVICE_A_LISTEN_PORT, - WGDEVICE_A_FWMARK, - WGDEVICE_A_PEERS, - __WGDEVICE_A_LAST -}; - -enum wgpeer_flag { - WGPEER_F_REMOVE_ME = 1U << 0, - WGPEER_F_REPLACE_ALLOWEDIPS = 1U << 1 -}; -enum wgpeer_attribute { - WGPEER_A_UNSPEC, - WGPEER_A_PUBLIC_KEY, - WGPEER_A_PRESHARED_KEY, - WGPEER_A_FLAGS, - WGPEER_A_ENDPOINT, - WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, - WGPEER_A_LAST_HANDSHAKE_TIME, - WGPEER_A_RX_BYTES, - WGPEER_A_TX_BYTES, - WGPEER_A_ALLOWEDIPS, - WGPEER_A_PROTOCOL_VERSION, - __WGPEER_A_LAST -}; - -enum wgallowedip_attribute { - WGALLOWEDIP_A_UNSPEC, - WGALLOWEDIP_A_FAMILY, - WGALLOWEDIP_A_IPADDR, - WGALLOWEDIP_A_CIDR_MASK, - __WGALLOWEDIP_A_LAST -}; - -/* libmnl mini library: */ - -#define MNL_SOCKET_AUTOPID 0 -#define MNL_ALIGNTO 4 -#define MNL_ALIGN(len) (((len)+MNL_ALIGNTO-1) & ~(MNL_ALIGNTO-1)) -#define MNL_NLMSG_HDRLEN MNL_ALIGN(sizeof(struct nlmsghdr)) -#define MNL_ATTR_HDRLEN MNL_ALIGN(sizeof(struct nlattr)) - -enum mnl_attr_data_type { - MNL_TYPE_UNSPEC, - MNL_TYPE_U8, - MNL_TYPE_U16, - MNL_TYPE_U32, - MNL_TYPE_U64, - MNL_TYPE_STRING, - MNL_TYPE_FLAG, - MNL_TYPE_MSECS, - MNL_TYPE_NESTED, - MNL_TYPE_NESTED_COMPAT, - MNL_TYPE_NUL_STRING, - MNL_TYPE_BINARY, - MNL_TYPE_MAX, -}; - -#define mnl_attr_for_each(attr, nlh, offset) \ - for ((attr) = mnl_nlmsg_get_payload_offset((nlh), (offset)); \ - mnl_attr_ok((attr), (char *)mnl_nlmsg_get_payload_tail(nlh) - (char *)(attr)); \ - (attr) = mnl_attr_next(attr)) - -#define mnl_attr_for_each_nested(attr, nest) \ - for ((attr) = mnl_attr_get_payload(nest); \ - mnl_attr_ok((attr), (char *)mnl_attr_get_payload(nest) + mnl_attr_get_payload_len(nest) - (char *)(attr)); \ - (attr) = mnl_attr_next(attr)) - -#define mnl_attr_for_each_payload(payload, payload_size) \ - for ((attr) = (payload); \ - mnl_attr_ok((attr), (char *)(payload) + payload_size - (char *)(attr)); \ - (attr) = mnl_attr_next(attr)) - -#define MNL_CB_ERROR -1 -#define MNL_CB_STOP 0 -#define MNL_CB_OK 1 - -typedef int (*mnl_attr_cb_t)(const struct nlattr *attr, void *data); -typedef int (*mnl_cb_t)(const struct nlmsghdr *nlh, void *data); - -#ifndef MNL_ARRAY_SIZE -#define MNL_ARRAY_SIZE(a) (sizeof(a)/sizeof((a)[0])) -#endif - -static size_t mnl_ideal_socket_buffer_size(void) -{ - static size_t size = 0; - - if (size) - return size; - size = (size_t)sysconf(_SC_PAGESIZE); - if (size > 8192) - size = 8192; - return size; -} - -static size_t mnl_nlmsg_size(size_t len) -{ - return len + MNL_NLMSG_HDRLEN; -} - -static struct nlmsghdr *mnl_nlmsg_put_header(void *buf) -{ - int len = MNL_ALIGN(sizeof(struct nlmsghdr)); - struct nlmsghdr *nlh = buf; - - memset(buf, 0, len); - nlh->nlmsg_len = len; - return nlh; -} - -static void *mnl_nlmsg_put_extra_header(struct nlmsghdr *nlh, size_t size) -{ - char *ptr = (char *)nlh + nlh->nlmsg_len; - size_t len = MNL_ALIGN(size); - nlh->nlmsg_len += len; - memset(ptr, 0, len); - return ptr; -} - -static void *mnl_nlmsg_get_payload(const struct nlmsghdr *nlh) -{ - return (void *)nlh + MNL_NLMSG_HDRLEN; -} - -static void *mnl_nlmsg_get_payload_offset(const struct nlmsghdr *nlh, size_t offset) -{ - return (void *)nlh + MNL_NLMSG_HDRLEN + MNL_ALIGN(offset); -} - -static bool mnl_nlmsg_ok(const struct nlmsghdr *nlh, int len) -{ - return len >= (int)sizeof(struct nlmsghdr) && - nlh->nlmsg_len >= sizeof(struct nlmsghdr) && - (int)nlh->nlmsg_len <= len; -} - -static struct nlmsghdr *mnl_nlmsg_next(const struct nlmsghdr *nlh, int *len) -{ - *len -= MNL_ALIGN(nlh->nlmsg_len); - return (struct nlmsghdr *)((void *)nlh + MNL_ALIGN(nlh->nlmsg_len)); -} - -static void *mnl_nlmsg_get_payload_tail(const struct nlmsghdr *nlh) -{ - return (void *)nlh + MNL_ALIGN(nlh->nlmsg_len); -} - -static bool mnl_nlmsg_seq_ok(const struct nlmsghdr *nlh, unsigned int seq) -{ - return nlh->nlmsg_seq && seq ? nlh->nlmsg_seq == seq : true; -} - -static bool mnl_nlmsg_portid_ok(const struct nlmsghdr *nlh, unsigned int portid) -{ - return nlh->nlmsg_pid && portid ? nlh->nlmsg_pid == portid : true; -} - -static uint16_t mnl_attr_get_type(const struct nlattr *attr) -{ - return attr->nla_type & NLA_TYPE_MASK; -} - -static uint16_t mnl_attr_get_payload_len(const struct nlattr *attr) -{ - return attr->nla_len - MNL_ATTR_HDRLEN; -} - -static void *mnl_attr_get_payload(const struct nlattr *attr) -{ - return (void *)attr + MNL_ATTR_HDRLEN; -} - -static bool mnl_attr_ok(const struct nlattr *attr, int len) -{ - return len >= (int)sizeof(struct nlattr) && - attr->nla_len >= sizeof(struct nlattr) && - (int)attr->nla_len <= len; -} - -static struct nlattr *mnl_attr_next(const struct nlattr *attr) -{ - return (struct nlattr *)((void *)attr + MNL_ALIGN(attr->nla_len)); -} - -static int mnl_attr_type_valid(const struct nlattr *attr, uint16_t max) -{ - if (mnl_attr_get_type(attr) > max) { - errno = EOPNOTSUPP; - return -1; - } - return 1; -} - -static int __mnl_attr_validate(const struct nlattr *attr, - enum mnl_attr_data_type type, size_t exp_len) -{ - uint16_t attr_len = mnl_attr_get_payload_len(attr); - const char *attr_data = mnl_attr_get_payload(attr); - - if (attr_len < exp_len) { - errno = ERANGE; - return -1; - } - switch(type) { - case MNL_TYPE_FLAG: - if (attr_len > 0) { - errno = ERANGE; - return -1; - } - break; - case MNL_TYPE_NUL_STRING: - if (attr_len == 0) { - errno = ERANGE; - return -1; - } - if (attr_data[attr_len-1] != '\0') { - errno = EINVAL; - return -1; - } - break; - case MNL_TYPE_STRING: - if (attr_len == 0) { - errno = ERANGE; - return -1; - } - break; - case MNL_TYPE_NESTED: - - if (attr_len == 0) - break; - - if (attr_len < MNL_ATTR_HDRLEN) { - errno = ERANGE; - return -1; - } - break; - default: - - break; - } - if (exp_len && attr_len > exp_len) { - errno = ERANGE; - return -1; - } - return 0; -} - -static const size_t mnl_attr_data_type_len[MNL_TYPE_MAX] = { - [MNL_TYPE_U8] = sizeof(uint8_t), - [MNL_TYPE_U16] = sizeof(uint16_t), - [MNL_TYPE_U32] = sizeof(uint32_t), - [MNL_TYPE_U64] = sizeof(uint64_t), - [MNL_TYPE_MSECS] = sizeof(uint64_t), -}; - -static int mnl_attr_validate(const struct nlattr *attr, enum mnl_attr_data_type type) -{ - int exp_len; - - if (type >= MNL_TYPE_MAX) { - errno = EINVAL; - return -1; - } - exp_len = mnl_attr_data_type_len[type]; - return __mnl_attr_validate(attr, type, exp_len); -} - -static int mnl_attr_parse(const struct nlmsghdr *nlh, unsigned int offset, - mnl_attr_cb_t cb, void *data) -{ - int ret = MNL_CB_OK; - const struct nlattr *attr; - - mnl_attr_for_each(attr, nlh, offset) - if ((ret = cb(attr, data)) <= MNL_CB_STOP) - return ret; - return ret; -} - -static int mnl_attr_parse_nested(const struct nlattr *nested, mnl_attr_cb_t cb, - void *data) -{ - int ret = MNL_CB_OK; - const struct nlattr *attr; - - mnl_attr_for_each_nested(attr, nested) - if ((ret = cb(attr, data)) <= MNL_CB_STOP) - return ret; - return ret; -} - -static uint8_t mnl_attr_get_u8(const struct nlattr *attr) -{ - return *((uint8_t *)mnl_attr_get_payload(attr)); -} - -static uint16_t mnl_attr_get_u16(const struct nlattr *attr) -{ - return *((uint16_t *)mnl_attr_get_payload(attr)); -} - -static uint32_t mnl_attr_get_u32(const struct nlattr *attr) -{ - return *((uint32_t *)mnl_attr_get_payload(attr)); -} - -static uint64_t mnl_attr_get_u64(const struct nlattr *attr) -{ - uint64_t tmp; - memcpy(&tmp, mnl_attr_get_payload(attr), sizeof(tmp)); - return tmp; -} - -static const char *mnl_attr_get_str(const struct nlattr *attr) -{ - return mnl_attr_get_payload(attr); -} - -static void mnl_attr_put(struct nlmsghdr *nlh, uint16_t type, size_t len, - const void *data) -{ - struct nlattr *attr = mnl_nlmsg_get_payload_tail(nlh); - uint16_t payload_len = MNL_ALIGN(sizeof(struct nlattr)) + len; - int pad; - - attr->nla_type = type; - attr->nla_len = payload_len; - memcpy(mnl_attr_get_payload(attr), data, len); - nlh->nlmsg_len += MNL_ALIGN(payload_len); - pad = MNL_ALIGN(len) - len; - if (pad > 0) - memset(mnl_attr_get_payload(attr) + len, 0, pad); -} - -static void mnl_attr_put_u16(struct nlmsghdr *nlh, uint16_t type, uint16_t data) -{ - mnl_attr_put(nlh, type, sizeof(uint16_t), &data); -} - -static void mnl_attr_put_u32(struct nlmsghdr *nlh, uint16_t type, uint32_t data) -{ - mnl_attr_put(nlh, type, sizeof(uint32_t), &data); -} - -static void mnl_attr_put_strz(struct nlmsghdr *nlh, uint16_t type, const char *data) -{ - mnl_attr_put(nlh, type, strlen(data)+1, data); -} - -static struct nlattr *mnl_attr_nest_start(struct nlmsghdr *nlh, uint16_t type) -{ - struct nlattr *start = mnl_nlmsg_get_payload_tail(nlh); - - start->nla_type = NLA_F_NESTED | type; - nlh->nlmsg_len += MNL_ALIGN(sizeof(struct nlattr)); - return start; -} - -static bool mnl_attr_put_check(struct nlmsghdr *nlh, size_t buflen, - uint16_t type, size_t len, const void *data) -{ - if (nlh->nlmsg_len + MNL_ATTR_HDRLEN + MNL_ALIGN(len) > buflen) - return false; - mnl_attr_put(nlh, type, len, data); - return true; -} - -static bool mnl_attr_put_u8_check(struct nlmsghdr *nlh, size_t buflen, - uint16_t type, uint8_t data) -{ - return mnl_attr_put_check(nlh, buflen, type, sizeof(uint8_t), &data); -} - -static bool mnl_attr_put_u16_check(struct nlmsghdr *nlh, size_t buflen, - uint16_t type, uint16_t data) -{ - return mnl_attr_put_check(nlh, buflen, type, sizeof(uint16_t), &data); -} - -static bool mnl_attr_put_u32_check(struct nlmsghdr *nlh, size_t buflen, - uint16_t type, uint32_t data) -{ - return mnl_attr_put_check(nlh, buflen, type, sizeof(uint32_t), &data); -} - -static struct nlattr *mnl_attr_nest_start_check(struct nlmsghdr *nlh, size_t buflen, - uint16_t type) -{ - if (nlh->nlmsg_len + MNL_ATTR_HDRLEN > buflen) - return NULL; - return mnl_attr_nest_start(nlh, type); -} - -static void mnl_attr_nest_end(struct nlmsghdr *nlh, struct nlattr *start) -{ - start->nla_len = mnl_nlmsg_get_payload_tail(nlh) - (void *)start; -} - -static void mnl_attr_nest_cancel(struct nlmsghdr *nlh, struct nlattr *start) -{ - nlh->nlmsg_len -= mnl_nlmsg_get_payload_tail(nlh) - (void *)start; -} - -static int mnl_cb_noop(__attribute__((unused)) const struct nlmsghdr *nlh, __attribute__((unused)) void *data) -{ - return MNL_CB_OK; -} - -static int mnl_cb_error(const struct nlmsghdr *nlh, __attribute__((unused)) void *data) -{ - const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh); - - if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) { - errno = EBADMSG; - return MNL_CB_ERROR; - } - - if (err->error < 0) - errno = -err->error; - else - errno = err->error; - - return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR; -} - -static int mnl_cb_stop(__attribute__((unused)) const struct nlmsghdr *nlh, __attribute__((unused)) void *data) -{ - return MNL_CB_STOP; -} - -static const mnl_cb_t default_cb_array[NLMSG_MIN_TYPE] = { - [NLMSG_NOOP] = mnl_cb_noop, - [NLMSG_ERROR] = mnl_cb_error, - [NLMSG_DONE] = mnl_cb_stop, - [NLMSG_OVERRUN] = mnl_cb_noop, -}; - -static int __mnl_cb_run(const void *buf, size_t numbytes, - unsigned int seq, unsigned int portid, - mnl_cb_t cb_data, void *data, - const mnl_cb_t *cb_ctl_array, - unsigned int cb_ctl_array_len) -{ - int ret = MNL_CB_OK, len = numbytes; - const struct nlmsghdr *nlh = buf; - - while (mnl_nlmsg_ok(nlh, len)) { - - if (!mnl_nlmsg_portid_ok(nlh, portid)) { - errno = ESRCH; - return -1; - } - - if (!mnl_nlmsg_seq_ok(nlh, seq)) { - errno = EPROTO; - return -1; - } - - if (nlh->nlmsg_flags & NLM_F_DUMP_INTR) { - errno = EINTR; - return -1; - } - - if (nlh->nlmsg_type >= NLMSG_MIN_TYPE) { - if (cb_data){ - ret = cb_data(nlh, data); - if (ret <= MNL_CB_STOP) - goto out; - } - } else if (nlh->nlmsg_type < cb_ctl_array_len) { - if (cb_ctl_array && cb_ctl_array[nlh->nlmsg_type]) { - ret = cb_ctl_array[nlh->nlmsg_type](nlh, data); - if (ret <= MNL_CB_STOP) - goto out; - } - } else if (default_cb_array[nlh->nlmsg_type]) { - ret = default_cb_array[nlh->nlmsg_type](nlh, data); - if (ret <= MNL_CB_STOP) - goto out; - } - nlh = mnl_nlmsg_next(nlh, &len); - } -out: - return ret; -} - -static int mnl_cb_run2(const void *buf, size_t numbytes, unsigned int seq, - unsigned int portid, mnl_cb_t cb_data, void *data, - const mnl_cb_t *cb_ctl_array, unsigned int cb_ctl_array_len) -{ - return __mnl_cb_run(buf, numbytes, seq, portid, cb_data, data, - cb_ctl_array, cb_ctl_array_len); -} - -static int mnl_cb_run(const void *buf, size_t numbytes, unsigned int seq, - unsigned int portid, mnl_cb_t cb_data, void *data) -{ - return __mnl_cb_run(buf, numbytes, seq, portid, cb_data, data, NULL, 0); -} - -struct mnl_socket { - int fd; - struct sockaddr_nl addr; -}; - -static unsigned int mnl_socket_get_portid(const struct mnl_socket *nl) -{ - return nl->addr.nl_pid; -} - -static struct mnl_socket *__mnl_socket_open(int bus, int flags) -{ - struct mnl_socket *nl; - - nl = calloc(1, sizeof(struct mnl_socket)); - if (nl == NULL) - return NULL; - - nl->fd = socket(AF_NETLINK, SOCK_RAW | flags, bus); - if (nl->fd == -1) { - free(nl); - return NULL; - } - - return nl; -} - -static struct mnl_socket *mnl_socket_open(int bus) -{ - return __mnl_socket_open(bus, 0); -} - -static int mnl_socket_bind(struct mnl_socket *nl, unsigned int groups, pid_t pid) -{ - int ret; - socklen_t addr_len; - - nl->addr.nl_family = AF_NETLINK; - nl->addr.nl_groups = groups; - nl->addr.nl_pid = pid; - - ret = bind(nl->fd, (struct sockaddr *) &nl->addr, sizeof (nl->addr)); - if (ret < 0) - return ret; - - addr_len = sizeof(nl->addr); - ret = getsockname(nl->fd, (struct sockaddr *) &nl->addr, &addr_len); - if (ret < 0) - return ret; - - if (addr_len != sizeof(nl->addr)) { - errno = EINVAL; - return -1; - } - if (nl->addr.nl_family != AF_NETLINK) { - errno = EINVAL; - return -1; - } - return 0; -} - -static ssize_t mnl_socket_sendto(const struct mnl_socket *nl, const void *buf, - size_t len) -{ - static const struct sockaddr_nl snl = { - .nl_family = AF_NETLINK - }; - return sendto(nl->fd, buf, len, 0, - (struct sockaddr *) &snl, sizeof(snl)); -} - -static ssize_t mnl_socket_recvfrom(const struct mnl_socket *nl, void *buf, - size_t bufsiz) -{ - ssize_t ret; - struct sockaddr_nl addr; - struct iovec iov = { - .iov_base = buf, - .iov_len = bufsiz, - }; - struct msghdr msg = { - .msg_name = &addr, - .msg_namelen = sizeof(struct sockaddr_nl), - .msg_iov = &iov, - .msg_iovlen = 1, - .msg_control = NULL, - .msg_controllen = 0, - .msg_flags = 0, - }; - ret = recvmsg(nl->fd, &msg, 0); - if (ret == -1) - return ret; - - if (msg.msg_flags & MSG_TRUNC) { - errno = ENOSPC; - return -1; - } - if (msg.msg_namelen != sizeof(struct sockaddr_nl)) { - errno = EINVAL; - return -1; - } - return ret; -} - -static int mnl_socket_close(struct mnl_socket *nl) -{ - int ret = close(nl->fd); - free(nl); - return ret; -} - -/* mnlg mini library: */ - -struct mnlg_socket { - struct mnl_socket *nl; - char *buf; - uint16_t id; - uint8_t version; - unsigned int seq; - unsigned int portid; -}; - -static struct nlmsghdr *__mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd, - uint16_t flags, uint16_t id, - uint8_t version) -{ - struct nlmsghdr *nlh; - struct genlmsghdr *genl; - - nlh = mnl_nlmsg_put_header(nlg->buf); - nlh->nlmsg_type = id; - nlh->nlmsg_flags = flags; - nlg->seq = time(NULL); - nlh->nlmsg_seq = nlg->seq; - - genl = mnl_nlmsg_put_extra_header(nlh, sizeof(struct genlmsghdr)); - genl->cmd = cmd; - genl->version = version; - - return nlh; -} - -static struct nlmsghdr *mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd, - uint16_t flags) -{ - return __mnlg_msg_prepare(nlg, cmd, flags, nlg->id, nlg->version); -} - -static int mnlg_socket_send(struct mnlg_socket *nlg, const struct nlmsghdr *nlh) -{ - return mnl_socket_sendto(nlg->nl, nlh, nlh->nlmsg_len); -} - -static int mnlg_cb_noop(const struct nlmsghdr *nlh, void *data) -{ - (void)nlh; - (void)data; - return MNL_CB_OK; -} - -static int mnlg_cb_error(const struct nlmsghdr *nlh, void *data) -{ - const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh); - (void)data; - - if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) { - errno = EBADMSG; - return MNL_CB_ERROR; - } - /* Netlink subsystems returns the errno value with different signess */ - if (err->error < 0) - errno = -err->error; - else - errno = err->error; - - return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR; -} - -static int mnlg_cb_stop(const struct nlmsghdr *nlh, void *data) -{ - (void)data; - if (nlh->nlmsg_flags & NLM_F_MULTI && nlh->nlmsg_len == mnl_nlmsg_size(sizeof(int))) { - int error = *(int *)mnl_nlmsg_get_payload(nlh); - /* Netlink subsystems returns the errno value with different signess */ - if (error < 0) - errno = -error; - else - errno = error; - - return error == 0 ? MNL_CB_STOP : MNL_CB_ERROR; - } - return MNL_CB_STOP; -} - -static const mnl_cb_t mnlg_cb_array[] = { - [NLMSG_NOOP] = mnlg_cb_noop, - [NLMSG_ERROR] = mnlg_cb_error, - [NLMSG_DONE] = mnlg_cb_stop, - [NLMSG_OVERRUN] = mnlg_cb_noop, -}; - -static int mnlg_socket_recv_run(struct mnlg_socket *nlg, mnl_cb_t data_cb, void *data) -{ - int err; - - do { - err = mnl_socket_recvfrom(nlg->nl, nlg->buf, - mnl_ideal_socket_buffer_size()); - if (err <= 0) - break; - err = mnl_cb_run2(nlg->buf, err, nlg->seq, nlg->portid, - data_cb, data, mnlg_cb_array, MNL_ARRAY_SIZE(mnlg_cb_array)); - } while (err > 0); - - return err; -} - -static int get_family_id_attr_cb(const struct nlattr *attr, void *data) -{ - const struct nlattr **tb = data; - int type = mnl_attr_get_type(attr); - - if (mnl_attr_type_valid(attr, CTRL_ATTR_MAX) < 0) - return MNL_CB_ERROR; - - if (type == CTRL_ATTR_FAMILY_ID && - mnl_attr_validate(attr, MNL_TYPE_U16) < 0) - return MNL_CB_ERROR; - tb[type] = attr; - return MNL_CB_OK; -} - -static int get_family_id_cb(const struct nlmsghdr *nlh, void *data) -{ - uint16_t *p_id = data; - struct nlattr *tb[CTRL_ATTR_MAX + 1] = { 0 }; - - mnl_attr_parse(nlh, sizeof(struct genlmsghdr), get_family_id_attr_cb, tb); - if (!tb[CTRL_ATTR_FAMILY_ID]) - return MNL_CB_ERROR; - *p_id = mnl_attr_get_u16(tb[CTRL_ATTR_FAMILY_ID]); - return MNL_CB_OK; -} - -static struct mnlg_socket *mnlg_socket_open(const char *family_name, uint8_t version) -{ - struct mnlg_socket *nlg; - struct nlmsghdr *nlh; - int err; - - nlg = malloc(sizeof(*nlg)); - if (!nlg) - return NULL; - nlg->id = 0; - - err = -ENOMEM; - nlg->buf = malloc(mnl_ideal_socket_buffer_size()); - if (!nlg->buf) - goto err_buf_alloc; - - nlg->nl = mnl_socket_open(NETLINK_GENERIC); - if (!nlg->nl) { - err = -errno; - goto err_mnl_socket_open; - } - - if (mnl_socket_bind(nlg->nl, 0, MNL_SOCKET_AUTOPID) < 0) { - err = -errno; - goto err_mnl_socket_bind; - } - - nlg->portid = mnl_socket_get_portid(nlg->nl); - - nlh = __mnlg_msg_prepare(nlg, CTRL_CMD_GETFAMILY, - NLM_F_REQUEST | NLM_F_ACK, GENL_ID_CTRL, 1); - mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name); - - if (mnlg_socket_send(nlg, nlh) < 0) { - err = -errno; - goto err_mnlg_socket_send; - } - - errno = 0; - if (mnlg_socket_recv_run(nlg, get_family_id_cb, &nlg->id) < 0) { - errno = errno == ENOENT ? EPROTONOSUPPORT : errno; - err = errno ? -errno : -ENOSYS; - goto err_mnlg_socket_recv_run; - } - - nlg->version = version; - errno = 0; - return nlg; - -err_mnlg_socket_recv_run: -err_mnlg_socket_send: -err_mnl_socket_bind: - mnl_socket_close(nlg->nl); -err_mnl_socket_open: - free(nlg->buf); -err_buf_alloc: - free(nlg); - errno = -err; - return NULL; -} - -static void mnlg_socket_close(struct mnlg_socket *nlg) -{ - mnl_socket_close(nlg->nl); - free(nlg->buf); - free(nlg); -} - -/* wireguard-specific parts: */ - -struct string_list { - char *buffer; - size_t len; - size_t cap; -}; - -static int string_list_add(struct string_list *list, const char *str) -{ - size_t len = strlen(str) + 1; - - if (len == 1) - return 0; - - if (len >= list->cap - list->len) { - char *new_buffer; - size_t new_cap = list->cap * 2; - - if (new_cap < list->len +len + 1) - new_cap = list->len + len + 1; - new_buffer = realloc(list->buffer, new_cap); - if (!new_buffer) - return -errno; - list->buffer = new_buffer; - list->cap = new_cap; - } - memcpy(list->buffer + list->len, str, len); - list->len += len; - list->buffer[list->len] = '\0'; - return 0; -} - -struct interface { - const char *name; - bool is_wireguard; -}; - -static int parse_linkinfo(const struct nlattr *attr, void *data) -{ - struct interface *interface = data; - - if (mnl_attr_get_type(attr) == IFLA_INFO_KIND && !strcmp(WG_GENL_NAME, mnl_attr_get_str(attr))) - interface->is_wireguard = true; - return MNL_CB_OK; -} - -static int parse_infomsg(const struct nlattr *attr, void *data) -{ - struct interface *interface = data; - - if (mnl_attr_get_type(attr) == IFLA_LINKINFO) - return mnl_attr_parse_nested(attr, parse_linkinfo, data); - else if (mnl_attr_get_type(attr) == IFLA_IFNAME) - interface->name = mnl_attr_get_str(attr); - return MNL_CB_OK; -} - -static int read_devices_cb(const struct nlmsghdr *nlh, void *data) -{ - struct string_list *list = data; - struct interface interface = { 0 }; - int ret; - - ret = mnl_attr_parse(nlh, sizeof(struct ifinfomsg), parse_infomsg, &interface); - if (ret != MNL_CB_OK) - return ret; - if (interface.name && interface.is_wireguard) - ret = string_list_add(list, interface.name); - if (ret < 0) - return ret; - if (nlh->nlmsg_type != NLMSG_DONE) - return MNL_CB_OK + 1; - return MNL_CB_OK; -} - -static int fetch_device_names(struct string_list *list) -{ - struct mnl_socket *nl = NULL; - char *rtnl_buffer = NULL; - size_t message_len; - unsigned int portid, seq; - ssize_t len; - int ret = 0; - struct nlmsghdr *nlh; - struct ifinfomsg *ifm; - - ret = -ENOMEM; - rtnl_buffer = calloc(mnl_ideal_socket_buffer_size(), 1); - if (!rtnl_buffer) - goto cleanup; - - nl = mnl_socket_open(NETLINK_ROUTE); - if (!nl) { - ret = -errno; - goto cleanup; - } - - if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) { - ret = -errno; - goto cleanup; - } - - seq = time(NULL); - portid = mnl_socket_get_portid(nl); - nlh = mnl_nlmsg_put_header(rtnl_buffer); - nlh->nlmsg_type = RTM_GETLINK; - nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP; - nlh->nlmsg_seq = seq; - ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm)); - ifm->ifi_family = AF_UNSPEC; - message_len = nlh->nlmsg_len; - - if (mnl_socket_sendto(nl, rtnl_buffer, message_len) < 0) { - ret = -errno; - goto cleanup; - } - -another: - if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, mnl_ideal_socket_buffer_size())) < 0) { - ret = -errno; - goto cleanup; - } - if ((len = mnl_cb_run(rtnl_buffer, len, seq, portid, read_devices_cb, list)) < 0) { - /* Netlink returns NLM_F_DUMP_INTR if the set of all tunnels changed - * during the dump. That's unfortunate, but is pretty common on busy - * systems that are adding and removing tunnels all the time. Rather - * than retrying, potentially indefinitely, we just work with the - * partial results. */ - if (errno != EINTR) { - ret = -errno; - goto cleanup; - } - } - if (len == MNL_CB_OK + 1) - goto another; - ret = 0; - -cleanup: - free(rtnl_buffer); - if (nl) - mnl_socket_close(nl); - return ret; -} - -static int add_del_iface(const char *ifname, bool add) -{ - struct mnl_socket *nl = NULL; - char *rtnl_buffer; - ssize_t len; - int ret; - struct nlmsghdr *nlh; - struct ifinfomsg *ifm; - struct nlattr *nest; - - rtnl_buffer = calloc(mnl_ideal_socket_buffer_size(), 1); - if (!rtnl_buffer) { - ret = -ENOMEM; - goto cleanup; - } - - nl = mnl_socket_open(NETLINK_ROUTE); - if (!nl) { - ret = -errno; - goto cleanup; - } - - if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) { - ret = -errno; - goto cleanup; - } - - nlh = mnl_nlmsg_put_header(rtnl_buffer); - nlh->nlmsg_type = add ? RTM_NEWLINK : RTM_DELLINK; - nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | (add ? NLM_F_CREATE | NLM_F_EXCL : 0); - nlh->nlmsg_seq = time(NULL); - ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm)); - ifm->ifi_family = AF_UNSPEC; - mnl_attr_put_strz(nlh, IFLA_IFNAME, ifname); - nest = mnl_attr_nest_start(nlh, IFLA_LINKINFO); - mnl_attr_put_strz(nlh, IFLA_INFO_KIND, WG_GENL_NAME); - mnl_attr_nest_end(nlh, nest); - - if (mnl_socket_sendto(nl, rtnl_buffer, nlh->nlmsg_len) < 0) { - ret = -errno; - goto cleanup; - } - if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, mnl_ideal_socket_buffer_size())) < 0) { - ret = -errno; - goto cleanup; - } - if (mnl_cb_run(rtnl_buffer, len, nlh->nlmsg_seq, mnl_socket_get_portid(nl), NULL, NULL) < 0) { - ret = -errno; - goto cleanup; - } - ret = 0; - -cleanup: - free(rtnl_buffer); - if (nl) - mnl_socket_close(nl); - return ret; -} - -int wg_set_device(wg_device *dev) -{ - int ret = 0; - wg_peer *peer = NULL; - wg_allowedip *allowedip = NULL; - struct nlattr *peers_nest, *peer_nest, *allowedips_nest, *allowedip_nest; - struct nlmsghdr *nlh; - struct mnlg_socket *nlg; - - nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION); - if (!nlg) - return -errno; - -again: - nlh = mnlg_msg_prepare(nlg, WG_CMD_SET_DEVICE, NLM_F_REQUEST | NLM_F_ACK); - mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, dev->name); - - if (!peer) { - uint32_t flags = 0; - - if (dev->flags & WGDEVICE_HAS_PRIVATE_KEY) - mnl_attr_put(nlh, WGDEVICE_A_PRIVATE_KEY, sizeof(dev->private_key), dev->private_key); - if (dev->flags & WGDEVICE_HAS_LISTEN_PORT) - mnl_attr_put_u16(nlh, WGDEVICE_A_LISTEN_PORT, dev->listen_port); - if (dev->flags & WGDEVICE_HAS_FWMARK) - mnl_attr_put_u32(nlh, WGDEVICE_A_FWMARK, dev->fwmark); - if (dev->flags & WGDEVICE_REPLACE_PEERS) - flags |= WGDEVICE_F_REPLACE_PEERS; - if (flags) - mnl_attr_put_u32(nlh, WGDEVICE_A_FLAGS, flags); - } - if (!dev->first_peer) - goto send; - peers_nest = peer_nest = allowedips_nest = allowedip_nest = NULL; - peers_nest = mnl_attr_nest_start(nlh, WGDEVICE_A_PEERS); - for (peer = peer ? peer : dev->first_peer; peer; peer = peer->next_peer) { - uint32_t flags = 0; - - peer_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), 0); - if (!peer_nest) - goto toobig_peers; - if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PUBLIC_KEY, sizeof(peer->public_key), peer->public_key)) - goto toobig_peers; - if (peer->flags & WGPEER_REMOVE_ME) - flags |= WGPEER_F_REMOVE_ME; - if (!allowedip) { - if (peer->flags & WGPEER_REPLACE_ALLOWEDIPS) - flags |= WGPEER_F_REPLACE_ALLOWEDIPS; - if (peer->flags & WGPEER_HAS_PRESHARED_KEY) { - if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PRESHARED_KEY, sizeof(peer->preshared_key), peer->preshared_key)) - goto toobig_peers; - } - if (peer->endpoint.addr.sa_family == AF_INET) { - if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr4), &peer->endpoint.addr4)) - goto toobig_peers; - } else if (peer->endpoint.addr.sa_family == AF_INET6) { - if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr6), &peer->endpoint.addr6)) - goto toobig_peers; - } - if (peer->flags & WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL) { - if (!mnl_attr_put_u16_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval)) - goto toobig_peers; - } - } - if (flags) { - if (!mnl_attr_put_u32_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_FLAGS, flags)) - goto toobig_peers; - } - if (peer->first_allowedip) { - if (!allowedip) - allowedip = peer->first_allowedip; - allowedips_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ALLOWEDIPS); - if (!allowedips_nest) - goto toobig_allowedips; - for (; allowedip; allowedip = allowedip->next_allowedip) { - allowedip_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), 0); - if (!allowedip_nest) - goto toobig_allowedips; - if (!mnl_attr_put_u16_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_FAMILY, allowedip->family)) - goto toobig_allowedips; - if (allowedip->family == AF_INET) { - if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip4), &allowedip->ip4)) - goto toobig_allowedips; - } else if (allowedip->family == AF_INET6) { - if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip6), &allowedip->ip6)) - goto toobig_allowedips; - } - if (!mnl_attr_put_u8_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_CIDR_MASK, allowedip->cidr)) - goto toobig_allowedips; - mnl_attr_nest_end(nlh, allowedip_nest); - allowedip_nest = NULL; - } - mnl_attr_nest_end(nlh, allowedips_nest); - allowedips_nest = NULL; - } - - mnl_attr_nest_end(nlh, peer_nest); - peer_nest = NULL; - } - mnl_attr_nest_end(nlh, peers_nest); - peers_nest = NULL; - goto send; -toobig_allowedips: - if (allowedip_nest) - mnl_attr_nest_cancel(nlh, allowedip_nest); - if (allowedips_nest) - mnl_attr_nest_end(nlh, allowedips_nest); - mnl_attr_nest_end(nlh, peer_nest); - mnl_attr_nest_end(nlh, peers_nest); - goto send; -toobig_peers: - if (peer_nest) - mnl_attr_nest_cancel(nlh, peer_nest); - mnl_attr_nest_end(nlh, peers_nest); - goto send; -send: - if (mnlg_socket_send(nlg, nlh) < 0) { - ret = -errno; - goto out; - } - errno = 0; - if (mnlg_socket_recv_run(nlg, NULL, NULL) < 0) { - ret = errno ? -errno : -EINVAL; - goto out; - } - if (peer) - goto again; - -out: - mnlg_socket_close(nlg); - errno = -ret; - return ret; -} - -static int parse_allowedip(const struct nlattr *attr, void *data) -{ - wg_allowedip *allowedip = data; - - switch (mnl_attr_get_type(attr)) { - case WGALLOWEDIP_A_UNSPEC: - break; - case WGALLOWEDIP_A_FAMILY: - if (!mnl_attr_validate(attr, MNL_TYPE_U16)) - allowedip->family = mnl_attr_get_u16(attr); - break; - case WGALLOWEDIP_A_IPADDR: - if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip4)) - memcpy(&allowedip->ip4, mnl_attr_get_payload(attr), sizeof(allowedip->ip4)); - else if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip6)) - memcpy(&allowedip->ip6, mnl_attr_get_payload(attr), sizeof(allowedip->ip6)); - break; - case WGALLOWEDIP_A_CIDR_MASK: - if (!mnl_attr_validate(attr, MNL_TYPE_U8)) - allowedip->cidr = mnl_attr_get_u8(attr); - break; - } - - return MNL_CB_OK; -} - -static int parse_allowedips(const struct nlattr *attr, void *data) -{ - wg_peer *peer = data; - wg_allowedip *new_allowedip = calloc(1, sizeof(wg_allowedip)); - int ret; - - if (!new_allowedip) - return MNL_CB_ERROR; - if (!peer->first_allowedip) - peer->first_allowedip = peer->last_allowedip = new_allowedip; - else { - peer->last_allowedip->next_allowedip = new_allowedip; - peer->last_allowedip = new_allowedip; - } - ret = mnl_attr_parse_nested(attr, parse_allowedip, new_allowedip); - if (!ret) - return ret; - if (!((new_allowedip->family == AF_INET && new_allowedip->cidr <= 32) || (new_allowedip->family == AF_INET6 && new_allowedip->cidr <= 128))) { - errno = EAFNOSUPPORT; - return MNL_CB_ERROR; - } - return MNL_CB_OK; -} - -bool wg_key_is_zero(const wg_key key) -{ - volatile uint8_t acc = 0; - unsigned int i; - - for (i = 0; i < sizeof(wg_key); ++i) { - acc |= key[i]; - __asm__ ("" : "=r" (acc) : "0" (acc)); - } - return 1 & ((acc - 1) >> 8); -} - -static int parse_peer(const struct nlattr *attr, void *data) -{ - wg_peer *peer = data; - - switch (mnl_attr_get_type(attr)) { - case WGPEER_A_UNSPEC: - break; - case WGPEER_A_PUBLIC_KEY: - if (mnl_attr_get_payload_len(attr) == sizeof(peer->public_key)) { - memcpy(peer->public_key, mnl_attr_get_payload(attr), sizeof(peer->public_key)); - peer->flags |= WGPEER_HAS_PUBLIC_KEY; - } - break; - case WGPEER_A_PRESHARED_KEY: - if (mnl_attr_get_payload_len(attr) == sizeof(peer->preshared_key)) { - memcpy(peer->preshared_key, mnl_attr_get_payload(attr), sizeof(peer->preshared_key)); - if (!wg_key_is_zero(peer->preshared_key)) - peer->flags |= WGPEER_HAS_PRESHARED_KEY; - } - break; - case WGPEER_A_ENDPOINT: { - struct sockaddr *addr; - - if (mnl_attr_get_payload_len(attr) < sizeof(*addr)) - break; - addr = mnl_attr_get_payload(attr); - if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr4)) - memcpy(&peer->endpoint.addr4, addr, sizeof(peer->endpoint.addr4)); - else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr6)) - memcpy(&peer->endpoint.addr6, addr, sizeof(peer->endpoint.addr6)); - break; - } - case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: - if (!mnl_attr_validate(attr, MNL_TYPE_U16)) - peer->persistent_keepalive_interval = mnl_attr_get_u16(attr); - break; - case WGPEER_A_LAST_HANDSHAKE_TIME: - if (mnl_attr_get_payload_len(attr) == sizeof(peer->last_handshake_time)) - memcpy(&peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(peer->last_handshake_time)); - break; - case WGPEER_A_RX_BYTES: - if (!mnl_attr_validate(attr, MNL_TYPE_U64)) - peer->rx_bytes = mnl_attr_get_u64(attr); - break; - case WGPEER_A_TX_BYTES: - if (!mnl_attr_validate(attr, MNL_TYPE_U64)) - peer->tx_bytes = mnl_attr_get_u64(attr); - break; - case WGPEER_A_ALLOWEDIPS: - return mnl_attr_parse_nested(attr, parse_allowedips, peer); - } - - return MNL_CB_OK; -} - -static int parse_peers(const struct nlattr *attr, void *data) -{ - wg_device *device = data; - wg_peer *new_peer = calloc(1, sizeof(wg_peer)); - int ret; - - if (!new_peer) - return MNL_CB_ERROR; - if (!device->first_peer) - device->first_peer = device->last_peer = new_peer; - else { - device->last_peer->next_peer = new_peer; - device->last_peer = new_peer; - } - ret = mnl_attr_parse_nested(attr, parse_peer, new_peer); - if (!ret) - return ret; - if (!(new_peer->flags & WGPEER_HAS_PUBLIC_KEY)) { - errno = ENXIO; - return MNL_CB_ERROR; - } - return MNL_CB_OK; -} - -static int parse_device(const struct nlattr *attr, void *data) -{ - wg_device *device = data; - - switch (mnl_attr_get_type(attr)) { - case WGDEVICE_A_UNSPEC: - break; - case WGDEVICE_A_IFINDEX: - if (!mnl_attr_validate(attr, MNL_TYPE_U32)) - device->ifindex = mnl_attr_get_u32(attr); - break; - case WGDEVICE_A_IFNAME: - if (!mnl_attr_validate(attr, MNL_TYPE_STRING)) { - strncpy(device->name, mnl_attr_get_str(attr), sizeof(device->name) - 1); - device->name[sizeof(device->name) - 1] = '\0'; - } - break; - case WGDEVICE_A_PRIVATE_KEY: - if (mnl_attr_get_payload_len(attr) == sizeof(device->private_key)) { - memcpy(device->private_key, mnl_attr_get_payload(attr), sizeof(device->private_key)); - device->flags |= WGDEVICE_HAS_PRIVATE_KEY; - } - break; - case WGDEVICE_A_PUBLIC_KEY: - if (mnl_attr_get_payload_len(attr) == sizeof(device->public_key)) { - memcpy(device->public_key, mnl_attr_get_payload(attr), sizeof(device->public_key)); - device->flags |= WGDEVICE_HAS_PUBLIC_KEY; - } - break; - case WGDEVICE_A_LISTEN_PORT: - if (!mnl_attr_validate(attr, MNL_TYPE_U16)) - device->listen_port = mnl_attr_get_u16(attr); - break; - case WGDEVICE_A_FWMARK: - if (!mnl_attr_validate(attr, MNL_TYPE_U32)) - device->fwmark = mnl_attr_get_u32(attr); - break; - case WGDEVICE_A_PEERS: - return mnl_attr_parse_nested(attr, parse_peers, device); - } - - return MNL_CB_OK; -} - -static int read_device_cb(const struct nlmsghdr *nlh, void *data) -{ - return mnl_attr_parse(nlh, sizeof(struct genlmsghdr), parse_device, data); -} - -static void coalesce_peers(wg_device *device) -{ - wg_peer *old_next_peer, *peer = device->first_peer; - - while (peer && peer->next_peer) { - if (memcmp(peer->public_key, peer->next_peer->public_key, sizeof(wg_key))) { - peer = peer->next_peer; - continue; - } - if (!peer->first_allowedip) { - peer->first_allowedip = peer->next_peer->first_allowedip; - peer->last_allowedip = peer->next_peer->last_allowedip; - } else { - peer->last_allowedip->next_allowedip = peer->next_peer->first_allowedip; - peer->last_allowedip = peer->next_peer->last_allowedip; - } - old_next_peer = peer->next_peer; - peer->next_peer = old_next_peer->next_peer; - free(old_next_peer); - } -} - -int wg_get_device(wg_device **device, const char *device_name) -{ - int ret = 0; - struct nlmsghdr *nlh; - struct mnlg_socket *nlg; - -try_again: - *device = calloc(1, sizeof(wg_device)); - if (!*device) - return -errno; - - nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION); - if (!nlg) { - wg_free_device(*device); - *device = NULL; - return -errno; - } - - nlh = mnlg_msg_prepare(nlg, WG_CMD_GET_DEVICE, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP); - mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, device_name); - if (mnlg_socket_send(nlg, nlh) < 0) { - ret = -errno; - goto out; - } - errno = 0; - if (mnlg_socket_recv_run(nlg, read_device_cb, *device) < 0) { - ret = errno ? -errno : -EINVAL; - goto out; - } - coalesce_peers(*device); - -out: - if (nlg) - mnlg_socket_close(nlg); - if (ret) { - wg_free_device(*device); - if (ret == -EINTR) - goto try_again; - *device = NULL; - } - errno = -ret; - return ret; -} - -/* first\0second\0third\0forth\0last\0\0 */ -char *wg_list_device_names(void) -{ - struct string_list list = { 0 }; - int ret = fetch_device_names(&list); - - errno = -ret; - if (errno) { - free(list.buffer); - return NULL; - } - return list.buffer ?: strdup("\0"); -} - -int wg_add_device(const char *device_name) -{ - return add_del_iface(device_name, true); -} - -int wg_del_device(const char *device_name) -{ - return add_del_iface(device_name, false); -} - -void wg_free_device(wg_device *dev) -{ - wg_peer *peer, *np; - wg_allowedip *allowedip, *na; - - if (!dev) - return; - for (peer = dev->first_peer, np = peer ? peer->next_peer : NULL; peer; peer = np, np = peer ? peer->next_peer : NULL) { - for (allowedip = peer->first_allowedip, na = allowedip ? allowedip->next_allowedip : NULL; allowedip; allowedip = na, na = allowedip ? allowedip->next_allowedip : NULL) - free(allowedip); - free(peer); - } - free(dev); -} - -static void encode_base64(char dest[static 4], const uint8_t src[static 3]) -{ - const uint8_t input[] = { (src[0] >> 2) & 63, ((src[0] << 4) | (src[1] >> 4)) & 63, ((src[1] << 2) | (src[2] >> 6)) & 63, src[2] & 63 }; - unsigned int i; - - for (i = 0; i < 4; ++i) - dest[i] = input[i] + 'A' - + (((25 - input[i]) >> 8) & 6) - - (((51 - input[i]) >> 8) & 75) - - (((61 - input[i]) >> 8) & 15) - + (((62 - input[i]) >> 8) & 3); - -} - -void wg_key_to_base64(wg_key_b64_string base64, const wg_key key) -{ - unsigned int i; - - for (i = 0; i < 32 / 3; ++i) - encode_base64(&base64[i * 4], &key[i * 3]); - encode_base64(&base64[i * 4], (const uint8_t[]){ key[i * 3 + 0], key[i * 3 + 1], 0 }); - base64[sizeof(wg_key_b64_string) - 2] = '='; - base64[sizeof(wg_key_b64_string) - 1] = '\0'; -} - -static int decode_base64(const char src[static 4]) -{ - int val = 0; - unsigned int i; - - for (i = 0; i < 4; ++i) - val |= (-1 - + ((((('A' - 1) - src[i]) & (src[i] - ('Z' + 1))) >> 8) & (src[i] - 64)) - + ((((('a' - 1) - src[i]) & (src[i] - ('z' + 1))) >> 8) & (src[i] - 70)) - + ((((('0' - 1) - src[i]) & (src[i] - ('9' + 1))) >> 8) & (src[i] + 5)) - + ((((('+' - 1) - src[i]) & (src[i] - ('+' + 1))) >> 8) & 63) - + ((((('/' - 1) - src[i]) & (src[i] - ('/' + 1))) >> 8) & 64) - ) << (18 - 6 * i); - return val; -} - -int wg_key_from_base64(wg_key key, const wg_key_b64_string base64) -{ - unsigned int i; - int val; - volatile uint8_t ret = 0; - - if (strlen(base64) != sizeof(wg_key_b64_string) - 1 || base64[sizeof(wg_key_b64_string) - 2] != '=') { - errno = EINVAL; - goto out; - } - - for (i = 0; i < 32 / 3; ++i) { - val = decode_base64(&base64[i * 4]); - ret |= (uint32_t)val >> 31; - key[i * 3 + 0] = (val >> 16) & 0xff; - key[i * 3 + 1] = (val >> 8) & 0xff; - key[i * 3 + 2] = val & 0xff; - } - val = decode_base64((const char[]){ base64[i * 4 + 0], base64[i * 4 + 1], base64[i * 4 + 2], 'A' }); - ret |= ((uint32_t)val >> 31) | (val & 0xff); - key[i * 3 + 0] = (val >> 16) & 0xff; - key[i * 3 + 1] = (val >> 8) & 0xff; - errno = EINVAL & ~((ret - 1) >> 8); -out: - return -errno; -} - -typedef int64_t fe[16]; - -static __attribute__((noinline)) void memzero_explicit(void *s, size_t count) -{ - memset(s, 0, count); - __asm__ __volatile__("": :"r"(s) :"memory"); -} - -static void carry(fe o) -{ - int i; - - for (i = 0; i < 16; ++i) { - o[(i + 1) % 16] += (i == 15 ? 38 : 1) * (o[i] >> 16); - o[i] &= 0xffff; - } -} - -static void cswap(fe p, fe q, int b) -{ - int i; - int64_t t, c = ~(b - 1); - - for (i = 0; i < 16; ++i) { - t = c & (p[i] ^ q[i]); - p[i] ^= t; - q[i] ^= t; - } - - memzero_explicit(&t, sizeof(t)); - memzero_explicit(&c, sizeof(c)); - memzero_explicit(&b, sizeof(b)); -} - -static void pack(uint8_t *o, const fe n) -{ - int i, j, b; - fe m, t; - - memcpy(t, n, sizeof(t)); - carry(t); - carry(t); - carry(t); - for (j = 0; j < 2; ++j) { - m[0] = t[0] - 0xffed; - for (i = 1; i < 15; ++i) { - m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1); - m[i - 1] &= 0xffff; - } - m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1); - b = (m[15] >> 16) & 1; - m[14] &= 0xffff; - cswap(t, m, 1 - b); - } - for (i = 0; i < 16; ++i) { - o[2 * i] = t[i] & 0xff; - o[2 * i + 1] = t[i] >> 8; - } - - memzero_explicit(m, sizeof(m)); - memzero_explicit(t, sizeof(t)); - memzero_explicit(&b, sizeof(b)); -} - -static void add(fe o, const fe a, const fe b) -{ - int i; - - for (i = 0; i < 16; ++i) - o[i] = a[i] + b[i]; -} - -static void subtract(fe o, const fe a, const fe b) -{ - int i; - - for (i = 0; i < 16; ++i) - o[i] = a[i] - b[i]; -} - -static void multmod(fe o, const fe a, const fe b) -{ - int i, j; - int64_t t[31] = { 0 }; - - for (i = 0; i < 16; ++i) { - for (j = 0; j < 16; ++j) - t[i + j] += a[i] * b[j]; - } - for (i = 0; i < 15; ++i) - t[i] += 38 * t[i + 16]; - memcpy(o, t, sizeof(fe)); - carry(o); - carry(o); - - memzero_explicit(t, sizeof(t)); -} - -static void invert(fe o, const fe i) -{ - fe c; - int a; - - memcpy(c, i, sizeof(c)); - for (a = 253; a >= 0; --a) { - multmod(c, c, c); - if (a != 2 && a != 4) - multmod(c, c, i); - } - memcpy(o, c, sizeof(fe)); - - memzero_explicit(c, sizeof(c)); -} - -static void clamp_key(uint8_t *z) -{ - z[31] = (z[31] & 127) | 64; - z[0] &= 248; -} - -void wg_generate_public_key(wg_key public_key, const wg_key private_key) -{ - int i, r; - uint8_t z[32]; - fe a = { 1 }, b = { 9 }, c = { 0 }, d = { 1 }, e, f; - - memcpy(z, private_key, sizeof(z)); - clamp_key(z); - - for (i = 254; i >= 0; --i) { - r = (z[i >> 3] >> (i & 7)) & 1; - cswap(a, b, r); - cswap(c, d, r); - add(e, a, c); - subtract(a, a, c); - add(c, b, d); - subtract(b, b, d); - multmod(d, e, e); - multmod(f, a, a); - multmod(a, c, a); - multmod(c, b, e); - add(e, a, c); - subtract(a, a, c); - multmod(b, a, a); - subtract(c, d, f); - multmod(a, c, (const fe){ 0xdb41, 1 }); - add(a, a, d); - multmod(c, c, a); - multmod(a, d, f); - multmod(d, b, (const fe){ 9 }); - multmod(b, e, e); - cswap(a, b, r); - cswap(c, d, r); - } - invert(c, c); - multmod(a, a, c); - pack(public_key, a); - - memzero_explicit(&r, sizeof(r)); - memzero_explicit(z, sizeof(z)); - memzero_explicit(a, sizeof(a)); - memzero_explicit(b, sizeof(b)); - memzero_explicit(c, sizeof(c)); - memzero_explicit(d, sizeof(d)); - memzero_explicit(e, sizeof(e)); - memzero_explicit(f, sizeof(f)); -} - -void wg_generate_private_key(wg_key private_key) -{ - wg_generate_preshared_key(private_key); - clamp_key(private_key); -} - -void wg_generate_preshared_key(wg_key preshared_key) -{ - ssize_t ret; - size_t i; - int fd; -#if defined(__OpenBSD__) || (defined(__APPLE__) && MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12) || (defined(__GLIBC__) && (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25))) - if (!getentropy(preshared_key, sizeof(wg_key))) - return; -#endif -#if defined(__NR_getrandom) && defined(__linux__) - if (syscall(__NR_getrandom, preshared_key, sizeof(wg_key), 0) == sizeof(wg_key)) - return; -#endif - fd = open("/dev/urandom", O_RDONLY); - assert(fd >= 0); - for (i = 0; i < sizeof(wg_key); i += ret) { - ret = read(fd, preshared_key + i, sizeof(wg_key) - i); - assert(ret > 0); - } - close(fd); -} diff --git a/wireguard-control-sys/c/wireguard.h b/wireguard-control-sys/c/wireguard.h deleted file mode 100644 index 328fcb4..0000000 --- a/wireguard-control-sys/c/wireguard.h +++ /dev/null @@ -1,105 +0,0 @@ -/* SPDX-License-Identifier: LGPL-2.1+ */ -/* - * Copyright (C) 2015-2020 Jason A. Donenfeld . All Rights Reserved. - */ - -#ifndef WIREGUARD_H -#define WIREGUARD_H - -#include -#include -#include -#include -#include -#include - -typedef uint8_t wg_key[32]; -typedef char wg_key_b64_string[((sizeof(wg_key) + 2) / 3) * 4 + 1]; - -/* Cross platform __kernel_timespec */ -struct timespec64 { - int64_t tv_sec; - int64_t tv_nsec; -}; - -typedef struct wg_allowedip { - uint16_t family; - union { - struct in_addr ip4; - struct in6_addr ip6; - }; - uint8_t cidr; - struct wg_allowedip *next_allowedip; -} wg_allowedip; - -enum wg_peer_flags { - WGPEER_REMOVE_ME = 1U << 0, - WGPEER_REPLACE_ALLOWEDIPS = 1U << 1, - WGPEER_HAS_PUBLIC_KEY = 1U << 2, - WGPEER_HAS_PRESHARED_KEY = 1U << 3, - WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL = 1U << 4 -}; - -typedef union wg_endpoint { - struct sockaddr addr; - struct sockaddr_in addr4; - struct sockaddr_in6 addr6; -} wg_endpoint; - -typedef struct wg_peer { - enum wg_peer_flags flags; - - wg_key public_key; - wg_key preshared_key; - - wg_endpoint endpoint; - - struct timespec64 last_handshake_time; - uint64_t rx_bytes, tx_bytes; - uint16_t persistent_keepalive_interval; - - struct wg_allowedip *first_allowedip, *last_allowedip; - struct wg_peer *next_peer; -} wg_peer; - -enum wg_device_flags { - WGDEVICE_REPLACE_PEERS = 1U << 0, - WGDEVICE_HAS_PRIVATE_KEY = 1U << 1, - WGDEVICE_HAS_PUBLIC_KEY = 1U << 2, - WGDEVICE_HAS_LISTEN_PORT = 1U << 3, - WGDEVICE_HAS_FWMARK = 1U << 4 -}; - -typedef struct wg_device { - char name[IFNAMSIZ]; - uint32_t ifindex; - - enum wg_device_flags flags; - - wg_key public_key; - wg_key private_key; - - uint32_t fwmark; - uint16_t listen_port; - - struct wg_peer *first_peer, *last_peer; -} wg_device; - -#define wg_for_each_device_name(__names, __name, __len) for ((__name) = (__names), (__len) = 0; ((__len) = strlen(__name)); (__name) += (__len) + 1) -#define wg_for_each_peer(__dev, __peer) for ((__peer) = (__dev)->first_peer; (__peer); (__peer) = (__peer)->next_peer) -#define wg_for_each_allowedip(__peer, __allowedip) for ((__allowedip) = (__peer)->first_allowedip; (__allowedip); (__allowedip) = (__allowedip)->next_allowedip) - -int wg_set_device(wg_device *dev); -int wg_get_device(wg_device **dev, const char *device_name); -int wg_add_device(const char *device_name); -int wg_del_device(const char *device_name); -void wg_free_device(wg_device *dev); -char *wg_list_device_names(void); /* first\0second\0third\0forth\0last\0\0 */ -void wg_key_to_base64(wg_key_b64_string base64, const wg_key key); -int wg_key_from_base64(wg_key key, const wg_key_b64_string base64); -bool wg_key_is_zero(const wg_key key); -void wg_generate_public_key(wg_key public_key, const wg_key private_key); -void wg_generate_private_key(wg_key private_key); -void wg_generate_preshared_key(wg_key preshared_key); - -#endif diff --git a/wireguard-control-sys/src/lib.rs b/wireguard-control-sys/src/lib.rs deleted file mode 100644 index 0c6218b..0000000 --- a/wireguard-control-sys/src/lib.rs +++ /dev/null @@ -1,8 +0,0 @@ -#![allow(non_upper_case_globals)] -#![allow(non_camel_case_types)] -#![allow(non_snake_case)] -// https://github.com/rust-lang/rust-bindgen/issues/1651 -#![allow(deref_nullptr)] - -#[cfg(target_os = "linux")] -include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/wireguard-control-sys/upgrade.sh b/wireguard-control-sys/upgrade.sh deleted file mode 100755 index 4932867..0000000 --- a/wireguard-control-sys/upgrade.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash -e -# This script modified from https://github.com/rusqlite/rusqlite/blob/master/libsqlite3-sys/upgrade.sh - -SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -CUR_DIR=$(pwd -P) -echo "$SCRIPT_DIR" -cd "$SCRIPT_DIR" || { echo "fatal error" >&2; exit 1; } -cargo clean -mkdir -p "$SCRIPT_DIR/../target" - -pushd "$SCRIPT_DIR/c" -curl -O https://raw.githubusercontent.com/WireGuard/wireguard-tools/master/contrib/embeddable-wg-library/wireguard.c -curl -O https://raw.githubusercontent.com/WireGuard/wireguard-tools/master/contrib/embeddable-wg-library/wireguard.h -popd - -# Regenerate bindgen file -rm -f "bindgen-bindings/bindings.rs" -# Just to make sure there is only one bindgen.rs file in target dir -find "$SCRIPT_DIR/../target" -type f -name bindings.rs -exec rm {} \; -cargo build --features "buildtime_bindgen" -find "$SCRIPT_DIR/../target" -type f -name bindings.rs -exec mv {} "$SCRIPT_DIR/bindgen-bindings/bindings.rs" \; - -# Sanity checks -cd "$SCRIPT_DIR" || { echo "fatal error" >&2; exit 1; } -cargo test -echo 'You should increment the version in Cargo.toml' \ No newline at end of file diff --git a/wireguard-control/Cargo.toml b/wireguard-control/Cargo.toml index 999637c..92d0b40 100644 --- a/wireguard-control/Cargo.toml +++ b/wireguard-control/Cargo.toml @@ -13,10 +13,13 @@ version = "1.5.2" base64 = "0.13" hex = "0.4" libc = "0.2" - -[target.'cfg(target_os = "linux")'.dependencies] -wireguard-control-sys = { path = "../wireguard-control-sys", version = "1.5.2" } - -[target.'cfg(not(target_os = "linux"))'.dependencies] rand_core = "0.6" curve25519-dalek = "4.0.0-pre.1" + +[target.'cfg(target_os = "linux")'.dependencies] +netlink-request = { path = "../netlink-request" } +netlink-sys = "0.8" +netlink-packet-core = "0.4" +netlink-packet-generic = "0.3" +netlink-packet-route = "0.10" +netlink-packet-wireguard = { git = "https://github.com/mcginty/netlink", branch = "wireguard-fixes" } diff --git a/wireguard-control/examples/enumerate.rs b/wireguard-control/examples/enumerate.rs new file mode 100644 index 0000000..d6392c0 --- /dev/null +++ b/wireguard-control/examples/enumerate.rs @@ -0,0 +1,11 @@ +use wireguard_control::{Backend, Device}; + +#[cfg(target_os = "linux")] +const BACKEND: Backend = Backend::Kernel; +#[cfg(not(target_os = "linux"))] +const BACKEND: Backend = Backend::Userspace; + +fn main() { + let devices = Device::list(BACKEND).unwrap(); + println!("{:?}", devices); +} diff --git a/wireguard-control/src/backends/kernel.rs b/wireguard-control/src/backends/kernel.rs index 0e4663e..aea685a 100644 --- a/wireguard-control/src/backends/kernel.rs +++ b/wireguard-control/src/backends/kernel.rs @@ -1,536 +1,267 @@ use crate::{ - device::AllowedIp, Backend, Device, DeviceUpdate, InterfaceName, InvalidInterfaceName, - InvalidKey, PeerConfig, PeerConfigBuilder, PeerInfo, PeerStats, + device::AllowedIp, Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfig, + PeerConfigBuilder, PeerInfo, PeerStats, }; -use wireguard_control_sys::{timespec64, wg_device_flags as wgdf, wg_peer_flags as wgpf}; - -use std::{ - ffi::{CStr, CString}, - io, - net::{IpAddr, SocketAddr}, - os::raw::c_char, - ptr, str, - time::{Duration, SystemTime}, +use netlink_packet_core::{ + NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, }; +use netlink_packet_generic::GenlMessage; +use netlink_packet_route::{ + constants::*, + link::{ + self, + nlas::{Info, InfoKind}, + }, + LinkMessage, RtnlMessage, +}; +use netlink_packet_wireguard::{ + self, + constants::{WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REMOVE_ME, WGPEER_F_REPLACE_ALLOWEDIPS}, + nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeerAttrs}, + Wireguard, WireguardCmd, +}; +use netlink_request::{netlink_request_genl, netlink_request_rtnl}; -impl<'a> From<&'a wireguard_control_sys::wg_allowedip> for AllowedIp { - fn from(raw: &wireguard_control_sys::wg_allowedip) -> AllowedIp { - let addr = match i32::from(raw.family) { - libc::AF_INET => IpAddr::V4(unsafe { raw.__bindgen_anon_1.ip4.s_addr }.to_be().into()), - libc::AF_INET6 => { - IpAddr::V6(unsafe { raw.__bindgen_anon_1.ip6.__in6_u.__u6_addr8 }.into()) - }, - _ => unreachable!(format!("Unsupported socket family {}!", raw.family)), - }; +use std::{convert::TryFrom, io}; - AllowedIp { - address: addr, - cidr: raw.cidr, - } +macro_rules! get_nla_value { + ($nlas:expr, $e:ident, $v:ident) => { + $nlas.iter().find_map(|attr| match attr { + $e::$v(value) => Some(value), + _ => None, + }) + }; +} + +impl<'a> TryFrom> for AllowedIp { + type Error = io::Error; + + fn try_from(attrs: Vec) -> Result { + let address = *get_nla_value!(attrs, WgAllowedIpAttrs, IpAddr) + .ok_or_else(|| io::ErrorKind::NotFound)?; + let cidr = *get_nla_value!(attrs, WgAllowedIpAttrs, Cidr) + .ok_or_else(|| io::ErrorKind::NotFound)?; + Ok(AllowedIp { address, cidr }) } } -impl<'a> From<&'a wireguard_control_sys::wg_peer> for PeerInfo { - fn from(raw: &wireguard_control_sys::wg_peer) -> PeerInfo { - PeerInfo { +impl AllowedIp { + fn to_attrs(&self) -> Vec { + vec![ + WgAllowedIpAttrs::Family(if self.address.is_ipv4() { + AF_INET + } else { + AF_INET6 + }), + WgAllowedIpAttrs::IpAddr(self.address), + WgAllowedIpAttrs::Cidr(self.cidr), + ] + } +} + +impl PeerConfigBuilder { + fn to_attrs(&self) -> Vec { + let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.0)]; + let mut flags = 0u32; + if let Some(endpoint) = self.endpoint { + attrs.push(WgPeerAttrs::Endpoint(endpoint)); + } + if let Some(ref key) = self.preshared_key { + attrs.push(WgPeerAttrs::PresharedKey(key.0)); + } + if let Some(i) = self.persistent_keepalive_interval { + attrs.push(WgPeerAttrs::PersistentKeepalive(i)); + } + let allowed_ips: Vec<_> = self.allowed_ips.iter().map(AllowedIp::to_attrs).collect(); + attrs.push(WgPeerAttrs::AllowedIps(allowed_ips)); + if self.remove_me { + flags |= WGPEER_F_REMOVE_ME; + } + if self.replace_allowed_ips { + flags |= WGPEER_F_REPLACE_ALLOWEDIPS; + } + if flags != 0 { + attrs.push(WgPeerAttrs::Flags(flags)); + } + attrs + } +} + +impl<'a> TryFrom> for PeerInfo { + type Error = io::Error; + + fn try_from(attrs: Vec) -> Result { + let public_key = get_nla_value!(attrs, WgPeerAttrs, PublicKey) + .map(|key| Key(*key)) + .ok_or(io::ErrorKind::NotFound)?; + let preshared_key = get_nla_value!(attrs, WgPeerAttrs, PresharedKey).map(|key| Key(*key)); + let endpoint = get_nla_value!(attrs, WgPeerAttrs, Endpoint).cloned(); + let persistent_keepalive_interval = + get_nla_value!(attrs, WgPeerAttrs, PersistentKeepalive).cloned(); + let allowed_ips = get_nla_value!(attrs, WgPeerAttrs, AllowedIps) + .cloned() + .unwrap_or_default() + .into_iter() + .map(AllowedIp::try_from) + .collect::, _>>()?; + let last_handshake_time = get_nla_value!(attrs, WgPeerAttrs, LastHandshake).cloned(); + let rx_bytes = get_nla_value!(attrs, WgPeerAttrs, RxBytes) + .cloned() + .unwrap_or_default(); + let tx_bytes = get_nla_value!(attrs, WgPeerAttrs, TxBytes) + .cloned() + .unwrap_or_default(); + Ok(PeerInfo { config: PeerConfig { - public_key: Key::from_raw(raw.public_key), - preshared_key: if (raw.flags & wgpf::WGPEER_HAS_PRESHARED_KEY).0 > 0 { - Some(Key::from_raw(raw.preshared_key)) - } else { - None - }, - endpoint: parse_endpoint(&raw.endpoint), - persistent_keepalive_interval: match raw.persistent_keepalive_interval { - 0 => None, - x => Some(x), - }, - allowed_ips: parse_allowed_ips(raw), + public_key, + preshared_key, + endpoint, + persistent_keepalive_interval, + allowed_ips, __cant_construct_me: (), }, stats: PeerStats { - last_handshake_time: match ( - raw.last_handshake_time.tv_sec, - raw.last_handshake_time.tv_nsec, - ) { - (0, 0) => None, - (s, ns) => Some(SystemTime::UNIX_EPOCH + Duration::new(s as u64, ns as u32)), - }, - rx_bytes: raw.rx_bytes, - tx_bytes: raw.tx_bytes, + last_handshake_time, + rx_bytes, + tx_bytes, }, - } + }) } } -impl<'a> From<&'a wireguard_control_sys::wg_device> for Device { - fn from(raw: &wireguard_control_sys::wg_device) -> Device { - // SAFETY: The name string buffer came directly from wgctrl so its NUL terminated. - let name = unsafe { InterfaceName::from_wg(raw.name) }; - Device { +impl<'a> TryFrom<&'a Wireguard> for Device { + type Error = io::Error; + + fn try_from(wg: &'a Wireguard) -> Result { + let name = get_nla_value!(wg.nlas, WgDeviceAttrs, IfName) + .ok_or_else(|| io::ErrorKind::NotFound)? + .parse()?; + let public_key = get_nla_value!(wg.nlas, WgDeviceAttrs, PublicKey).map(|key| Key(*key)); + let private_key = get_nla_value!(wg.nlas, WgDeviceAttrs, PrivateKey).map(|key| Key(*key)); + let listen_port = get_nla_value!(wg.nlas, WgDeviceAttrs, ListenPort).cloned(); + let fwmark = get_nla_value!(wg.nlas, WgDeviceAttrs, Fwmark).cloned(); + let peers = get_nla_value!(wg.nlas, WgDeviceAttrs, Peers) + .cloned() + .unwrap_or_default() + .into_iter() + .map(PeerInfo::try_from) + .collect::, _>>()?; + Ok(Device { name, - public_key: if (raw.flags & wgdf::WGDEVICE_HAS_PUBLIC_KEY).0 > 0 { - Some(Key::from_raw(raw.public_key)) - } else { - None - }, - private_key: if (raw.flags & wgdf::WGDEVICE_HAS_PRIVATE_KEY).0 > 0 { - Some(Key::from_raw(raw.private_key)) - } else { - None - }, - fwmark: match raw.fwmark { - 0 => None, - x => Some(x), - }, - listen_port: match raw.listen_port { - 0 => None, - x => Some(x), - }, - peers: parse_peers(raw), + public_key, + private_key, + listen_port, + fwmark, + peers, linked_name: None, backend: Backend::Kernel, __cant_construct_me: (), - } + }) } } -fn parse_peers(dev: &wireguard_control_sys::wg_device) -> Vec { - let mut result = Vec::new(); - - let mut current_peer = dev.first_peer; - - if current_peer.is_null() { - return result; - } - - loop { - let peer = unsafe { &*current_peer }; - - result.push(PeerInfo::from(peer)); - - if current_peer == dev.last_peer { - break; - } - current_peer = peer.next_peer; - } - - result -} - -fn parse_allowed_ips(peer: &wireguard_control_sys::wg_peer) -> Vec { - let mut result = Vec::new(); - - let mut current_ip: *mut wireguard_control_sys::wg_allowedip = peer.first_allowedip; - - if current_ip.is_null() { - return result; - } - - loop { - let ip = unsafe { &*current_ip }; - - result.push(AllowedIp::from(ip)); - - if current_ip == peer.last_allowedip { - break; - } - current_ip = ip.next_allowedip; - } - - result -} - -fn parse_endpoint(endpoint: &wireguard_control_sys::wg_endpoint) -> Option { - let addr = unsafe { endpoint.addr }; - match i32::from(addr.sa_family) { - libc::AF_INET => { - let addr4 = unsafe { endpoint.addr4 }; - Some(SocketAddr::new( - IpAddr::V4(u32::from_be(addr4.sin_addr.s_addr).into()), - u16::from_be(addr4.sin_port), - )) - }, - libc::AF_INET6 => { - let addr6 = unsafe { endpoint.addr6 }; - let bytes = unsafe { addr6.sin6_addr.__in6_u.__u6_addr8 }; - Some(SocketAddr::new( - IpAddr::V6(bytes.into()), - u16::from_be(addr6.sin6_port), - )) - }, - 0 => None, - _ => unreachable!(format!("Unsupported socket family: {}!", addr.sa_family)), - } -} - -fn encode_allowedips( - allowed_ips: &[AllowedIp], -) -> ( - *mut wireguard_control_sys::wg_allowedip, - *mut wireguard_control_sys::wg_allowedip, -) { - if allowed_ips.is_empty() { - return (ptr::null_mut(), ptr::null_mut()); - } - - let mut first_ip = ptr::null_mut(); - let mut last_ip: *mut wireguard_control_sys::wg_allowedip = ptr::null_mut(); - - for ip in allowed_ips { - let mut wg_allowedip = Box::new(wireguard_control_sys::wg_allowedip { - family: 0, - __bindgen_anon_1: Default::default(), - cidr: ip.cidr, - next_allowedip: first_ip, - }); - - match ip.address { - IpAddr::V4(a) => { - wg_allowedip.family = libc::AF_INET as u16; - wg_allowedip.__bindgen_anon_1.ip4.s_addr = u32::to_be(a.into()); - }, - IpAddr::V6(a) => { - wg_allowedip.family = libc::AF_INET6 as u16; - wg_allowedip.__bindgen_anon_1.ip6.__in6_u.__u6_addr8 = a.octets(); - }, - } - - first_ip = Box::into_raw(wg_allowedip); - if last_ip.is_null() { - last_ip = first_ip; - } - } - - (first_ip, last_ip) -} - -fn encode_endpoint(endpoint: Option) -> wireguard_control_sys::wg_endpoint { - match endpoint { - Some(SocketAddr::V4(s)) => { - let mut peer = wireguard_control_sys::wg_endpoint::default(); - peer.addr4 = wireguard_control_sys::sockaddr_in { - sin_family: libc::AF_INET as u16, - sin_addr: wireguard_control_sys::in_addr { - s_addr: u32::from_be((*s.ip()).into()), - }, - sin_port: u16::to_be(s.port()), - sin_zero: [0; 8], - }; - peer - }, - Some(SocketAddr::V6(s)) => { - let mut peer = wireguard_control_sys::wg_endpoint::default(); - let in6_addr = wireguard_control_sys::in6_addr__bindgen_ty_1 { - __u6_addr8: s.ip().octets(), - }; - peer.addr6 = wireguard_control_sys::sockaddr_in6 { - sin6_family: libc::AF_INET6 as u16, - sin6_addr: wireguard_control_sys::in6_addr { __in6_u: in6_addr }, - sin6_port: u16::to_be(s.port()), - sin6_flowinfo: 0, - sin6_scope_id: 0, - }; - peer - }, - None => wireguard_control_sys::wg_endpoint::default(), - } -} - -fn encode_peers( - peers: &[PeerConfigBuilder], -) -> ( - *mut wireguard_control_sys::wg_peer, - *mut wireguard_control_sys::wg_peer, -) { - let mut first_peer = ptr::null_mut(); - let mut last_peer: *mut wireguard_control_sys::wg_peer = ptr::null_mut(); - - for peer in peers { - let (first_allowedip, last_allowedip) = encode_allowedips(&peer.allowed_ips); - - let mut wg_peer = Box::new(wireguard_control_sys::wg_peer { - public_key: peer.public_key.0, - preshared_key: wireguard_control_sys::wg_key::default(), - endpoint: encode_endpoint(peer.endpoint), - last_handshake_time: timespec64 { - tv_sec: 0, - tv_nsec: 0, - }, - tx_bytes: 0, - rx_bytes: 0, - persistent_keepalive_interval: 0, - first_allowedip, - last_allowedip, - next_peer: first_peer, - flags: wgpf::WGPEER_HAS_PUBLIC_KEY, - }); - - if let Some(Key(k)) = peer.preshared_key { - wg_peer.flags |= wgpf::WGPEER_HAS_PRESHARED_KEY; - wg_peer.preshared_key = k; - } - - if let Some(n) = peer.persistent_keepalive_interval { - wg_peer.persistent_keepalive_interval = n; - wg_peer.flags |= wgpf::WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL; - } - - if peer.replace_allowed_ips { - wg_peer.flags |= wgpf::WGPEER_REPLACE_ALLOWEDIPS; - } - - if peer.remove_me { - wg_peer.flags |= wgpf::WGPEER_REMOVE_ME; - } - - first_peer = Box::into_raw(wg_peer); - if last_peer.is_null() { - last_peer = first_peer; - } - } - - (first_peer, last_peer) -} - pub fn enumerate() -> Result, io::Error> { - let base = unsafe { wireguard_control_sys::wg_list_device_names() }; + let link_responses = netlink_request_rtnl( + RtnlMessage::GetLink(LinkMessage::default()), + Some(NLM_F_DUMP | NLM_F_REQUEST), + )?; + let links = link_responses + .into_iter() + // Filter out non-link messages + .filter_map(|response| match response { + NetlinkMessage { + payload: NetlinkPayload::InnerMessage(RtnlMessage::NewLink(link)), + .. + } => Some(link), + _ => None, + }) + .filter(|link| { + for nla in link.nlas.iter() { + if let link::nlas::Nla::Info(infos) = nla { + return infos.iter().any(|info| info == &Info::Kind(InfoKind::Wireguard)) + } + } + false + }) + .filter_map(|link| link.nlas.iter().find_map(|nla| match nla { + link::nlas::Nla::IfName(name) => Some(name.clone()), + _ => None, + })) + .filter_map(|name| name.parse().ok()) + .collect::>(); - if base.is_null() { - return Err(io::Error::last_os_error()); + Ok(links) +} + +fn add_del(iface: &InterfaceName, add: bool) -> io::Result<()> { + let mut message = LinkMessage::default(); + message + .nlas + .push(link::nlas::Nla::IfName(iface.as_str_lossy().to_string())); + message.nlas.push(link::nlas::Nla::Info(vec![Info::Kind( + link::nlas::InfoKind::Wireguard, + )])); + let extra_flags = if add { NLM_F_CREATE | NLM_F_EXCL } else { 0 }; + let rtnl_message = if add { + RtnlMessage::NewLink(message) + } else { + RtnlMessage::DelLink(message) + }; + let result = netlink_request_rtnl(rtnl_message, Some(NLM_F_REQUEST | NLM_F_ACK | extra_flags)); + match result { + Err(e) if e.kind() != io::ErrorKind::AlreadyExists => Err(e), + _ => Ok(()), } - - let mut current = base; - let mut result = Vec::new(); - - loop { - let next_dev = unsafe { CStr::from_ptr(current).to_bytes() }; - - let len = next_dev.len(); - - if len == 0 { - break; - } - - current = unsafe { current.add(len + 1) }; - - let interface: InterfaceName = str::from_utf8(next_dev) - .map_err(|_| InvalidInterfaceName::InvalidChars)? - .parse()?; - - result.push(interface); - } - - unsafe { libc::free(base as *mut libc::c_void) }; - - Ok(result) } pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> { - let (first_peer, last_peer) = encode_peers(&builder.peers); - - let result = unsafe { wireguard_control_sys::wg_add_device(iface.as_ptr()) }; - match result { - 0 | -17 => {}, - _ => return Err(io::Error::last_os_error()), - }; - - let mut wg_device = Box::new(wireguard_control_sys::wg_device { - name: iface.into_inner(), - ifindex: 0, - public_key: wireguard_control_sys::wg_key::default(), - private_key: wireguard_control_sys::wg_key::default(), - fwmark: 0, - listen_port: 0, - first_peer, - last_peer, - flags: wgdf(0), - }); - - if let Some(Key(k)) = builder.public_key { - wg_device.public_key = k; - wg_device.flags |= wgdf::WGDEVICE_HAS_PUBLIC_KEY; - } - + add_del(iface, true)?; + let mut nlas = vec![WgDeviceAttrs::IfName(iface.as_str_lossy().to_string())]; if let Some(Key(k)) = builder.private_key { - wg_device.private_key = k; - wg_device.flags |= wgdf::WGDEVICE_HAS_PRIVATE_KEY; + nlas.push(WgDeviceAttrs::PrivateKey(k)); } - if let Some(f) = builder.fwmark { - wg_device.fwmark = f; - wg_device.flags |= wgdf::WGDEVICE_HAS_FWMARK; + nlas.push(WgDeviceAttrs::Fwmark(f)); } - if let Some(f) = builder.listen_port { - wg_device.listen_port = f; - wg_device.flags |= wgdf::WGDEVICE_HAS_LISTEN_PORT; + nlas.push(WgDeviceAttrs::ListenPort(f)); } - if builder.replace_peers { - wg_device.flags |= wgdf::WGDEVICE_REPLACE_PEERS; - } - - let ptr = Box::into_raw(wg_device); - let result = unsafe { wireguard_control_sys::wg_set_device(ptr) }; - - unsafe { wireguard_control_sys::wg_free_device(ptr) }; - - if result == 0 { - Ok(()) - } else { - Err(io::Error::last_os_error()) + nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); } + let peers: Vec> = builder + .peers + .iter() + .map(PeerConfigBuilder::to_attrs) + .collect(); + nlas.push(WgDeviceAttrs::Peers(peers)); + let genlmsg: GenlMessage = GenlMessage::from_payload(Wireguard { + cmd: WireguardCmd::SetDevice, + nlas, + }); + netlink_request_genl(genlmsg, Some(NLM_F_REQUEST | NLM_F_ACK))?; + Ok(()) } pub fn get_by_name(name: &InterfaceName) -> Result { - let mut device: *mut wireguard_control_sys::wg_device = ptr::null_mut(); + let genlmsg: GenlMessage = GenlMessage::from_payload(Wireguard { + cmd: WireguardCmd::GetDevice, + nlas: vec![WgDeviceAttrs::IfName(name.as_str_lossy().to_string())], + }); + let responses = netlink_request_genl(genlmsg, Some(NLM_F_REQUEST | NLM_F_DUMP | NLM_F_ACK))?; - let result = unsafe { - wireguard_control_sys::wg_get_device( - (&mut device) as *mut _ as *mut *mut wireguard_control_sys::wg_device, - name.as_ptr(), - ) - }; - - let result = if result == 0 { - Ok(Device::from(unsafe { &*device })) - } else { - Err(io::Error::last_os_error()) - }; - - unsafe { wireguard_control_sys::wg_free_device(device) }; - - result + match responses.get(0) { + Some(NetlinkMessage { + payload: NetlinkPayload::InnerMessage(message), + .. + }) => Device::try_from(&message.payload), + _ => Err(io::Error::new( + io::ErrorKind::InvalidData, + "Unexpected netlink payload", + )), + } } pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> { - let result = unsafe { wireguard_control_sys::wg_del_device(iface.as_ptr()) }; - - if result == 0 { - Ok(()) - } else { - Err(io::Error::last_os_error()) - } -} - -/// Represents a WireGuard encryption key. -/// -/// WireGuard makes no meaningful distinction between public, -/// private and preshared keys - any sequence of 32 bytes -/// can be used as either of those. -/// -/// This means that you need to be careful when working with -/// `Key`s, especially ones created from external data. -#[cfg(target_os = "linux")] -#[derive(PartialEq, Eq, Clone)] -pub struct Key(wireguard_control_sys::wg_key); - -#[cfg(target_os = "linux")] -impl Key { - /// Creates a new `Key` from raw bytes. - pub fn from_raw(key: wireguard_control_sys::wg_key) -> Self { - Self(key) - } - - /// Generates and returns a new private key. - pub fn generate_private() -> Self { - let mut private_key = wireguard_control_sys::wg_key::default(); - - unsafe { - wireguard_control_sys::wg_generate_private_key(private_key.as_mut_ptr()); - } - - Self(private_key) - } - - /// Generates and returns a new preshared key. - pub fn generate_preshared() -> Self { - let mut preshared_key = wireguard_control_sys::wg_key::default(); - - unsafe { - wireguard_control_sys::wg_generate_preshared_key(preshared_key.as_mut_ptr()); - } - - Self(preshared_key) - } - - /// Generates a public key for this private key. - pub fn generate_public(&self) -> Self { - let mut public_key = wireguard_control_sys::wg_key::default(); - - unsafe { - wireguard_control_sys::wg_generate_public_key( - public_key.as_mut_ptr(), - &self.0 as *const u8 as *mut u8, - ); - } - - Self(public_key) - } - - /// Generates an all-zero key. - pub fn zero() -> Self { - Self(wireguard_control_sys::wg_key::default()) - } - - pub fn as_bytes(&self) -> &[u8] { - &self.0 - } - - /// Converts the key to a standardized base64 representation, as used by the `wg` utility and `wg-quick`. - pub fn to_base64(&self) -> String { - let mut key_b64: wireguard_control_sys::wg_key_b64_string = [0; 45]; - unsafe { - wireguard_control_sys::wg_key_to_base64( - key_b64.as_mut_ptr(), - &self.0 as *const u8 as *mut u8, - ); - - str::from_utf8_unchecked(&*(&key_b64[..44] as *const [c_char] as *const [u8])).into() - } - } - - /// Converts a base64 representation of the key to the raw bytes. - /// - /// This can fail, as not all text input is valid base64 - in this case - /// `Err(InvalidKey)` is returned. - pub fn from_base64(key: &str) -> Result { - let mut decoded = wireguard_control_sys::wg_key::default(); - - let key_str = CString::new(key)?; - let result = unsafe { - wireguard_control_sys::wg_key_from_base64( - decoded.as_mut_ptr(), - key_str.as_ptr() as *mut _, - ) - }; - - if result == 0 { - Ok(Self { 0: decoded }) - } else { - Err(InvalidKey) - } - } - - pub fn from_hex(hex_str: &str) -> Result { - let bytes = hex::decode(hex_str).map_err(|_| InvalidKey)?; - Self::from_base64(&base64::encode(&bytes)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_encode_endpoint() -> Result<(), Box> { - let endpoint = Some("1.2.3.4:51820".parse()?); - let endpoint6: Option = Some("[2001:db8:1::1]:51820".parse()?); - let encoded = encode_endpoint(endpoint); - let encoded6 = encode_endpoint(endpoint6); - assert_eq!(endpoint, parse_endpoint(&encoded)); - assert_eq!(endpoint6, parse_endpoint(&encoded6)); - Ok(()) - } + add_del(iface, false) } diff --git a/wireguard-control/src/backends/userspace.rs b/wireguard-control/src/backends/userspace.rs index 6774aaf..23a3225 100644 --- a/wireguard-control/src/backends/userspace.rs +++ b/wireguard-control/src/backends/userspace.rs @@ -1,6 +1,5 @@ use crate::{Backend, Device, DeviceUpdate, InterfaceName, PeerConfig, PeerInfo, PeerStats}; -#[cfg(target_os = "linux")] use crate::Key; use std::{ @@ -382,131 +381,3 @@ pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> { _ => Err(io::ErrorKind::Other.into()), } } - -/// Represents a WireGuard encryption key. -/// -/// WireGuard makes no meaningful distinction between public, -/// private and preshared keys - any sequence of 32 bytes -/// can be used as either of those. -/// -/// This means that you need to be careful when working with -/// `Key`s, especially ones created from external data. -#[cfg(not(target_os = "linux"))] -#[derive(PartialEq, Eq, Clone)] -pub struct Key([u8; 32]); - -#[cfg(not(target_os = "linux"))] -impl Key { - /// Generates and returns a new private key. - pub fn generate_private() -> Self { - use rand_core::{OsRng, RngCore}; - - let mut bytes = [0u8; 32]; - OsRng.fill_bytes(&mut bytes); - - // Apply key clamping. - bytes[0] &= 248; - bytes[31] &= 127; - bytes[31] |= 64; - Self(bytes) - } - - /// Generates and returns a new preshared key. - pub fn generate_preshared() -> Self { - use rand_core::{OsRng, RngCore}; - - let mut key = [0u8; 32]; - OsRng.fill_bytes(&mut key); - Self(key) - } - - /// Generates a public key for this private key. - pub fn generate_public(&self) -> Self { - use curve25519_dalek::scalar::Scalar; - - use curve25519_dalek::constants::ED25519_BASEPOINT_TABLE; - - // https://github.com/dalek-cryptography/x25519-dalek/blob/1c39ff92e0dfc0b24aa02d694f26f3b9539322a5/src/x25519.rs#L150 - let point = (&ED25519_BASEPOINT_TABLE * &Scalar::from_bits(self.0)).to_montgomery(); - - Self(point.to_bytes()) - } - - /// Generates an all-zero key. - pub fn zero() -> Self { - Self([0u8; 32]) - } - - pub fn as_bytes(&self) -> &[u8] { - &self.0 - } - - /// Converts the key to a standardized base64 representation, as used by the `wg` utility and `wg-quick`. - pub fn to_base64(&self) -> String { - base64::encode(&self.0) - } - - /// Converts a base64 representation of the key to the raw bytes. - /// - /// This can fail, as not all text input is valid base64 - in this case - /// `Err(InvalidKey)` is returned. - pub fn from_base64(key: &str) -> Result { - use crate::InvalidKey; - - let mut key_bytes = [0u8; 32]; - let decoded_bytes = base64::decode(key).map_err(|_| InvalidKey)?; - - if decoded_bytes.len() != 32 { - return Err(InvalidKey); - } - - key_bytes.copy_from_slice(&decoded_bytes[..]); - Ok(Self(key_bytes)) - } - - pub fn from_hex(hex_str: &str) -> Result { - use crate::InvalidKey; - - let mut sized_bytes = [0u8; 32]; - hex::decode_to_slice(hex_str, &mut sized_bytes).map_err(|_| InvalidKey)?; - Ok(Self(sized_bytes)) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_pubkey_generation() { - let privkey = "SGb+ojrRNDuMePufwtIYhXzA//k6wF3R21tEBgKlzlM="; - let pubkey = "DD5yKRfzExcV5+kDnTroDgCU15latdMjiQ59j1hEuk8="; - - let private = Key::from_base64(privkey).unwrap(); - let public = Key::generate_public(&private); - - assert_eq!(public.to_base64(), pubkey); - } - - #[test] - fn test_rng_sanity_private() { - let first = Key::generate_private(); - assert!(first.as_bytes() != [0u8; 32]); - for _ in 0..100_000 { - let key = Key::generate_private(); - assert!(first != key); - assert!(key.as_bytes() != [0u8; 32]); - } - } - - #[test] - fn test_rng_sanity_preshared() { - let first = Key::generate_preshared(); - assert!(first.as_bytes() != [0u8; 32]); - for _ in 0..100_000 { - let key = Key::generate_preshared(); - assert!(first != key); - assert!(key.as_bytes() != [0u8; 32]); - } - } -} diff --git a/wireguard-control/src/config.rs b/wireguard-control/src/config.rs index 1a0fde1..909351b 100644 --- a/wireguard-control/src/config.rs +++ b/wireguard-control/src/config.rs @@ -102,29 +102,34 @@ impl PeerConfigBuilder { } /// Specifies a preshared key to be set for this peer. + #[must_use] pub fn set_preshared_key(mut self, key: Key) -> Self { self.preshared_key = Some(key); self } /// Specifies that this peer's preshared key should be unset. + #[must_use] pub fn unset_preshared_key(self) -> Self { self.set_preshared_key(Key::zero()) } /// Specifies an exact endpoint that this peer should be allowed to connect from. + #[must_use] pub fn set_endpoint(mut self, address: SocketAddr) -> Self { self.endpoint = Some(address); self } /// Specifies the interval between keepalive packets to be sent to this peer. + #[must_use] pub fn set_persistent_keepalive_interval(mut self, interval: u16) -> Self { self.persistent_keepalive_interval = Some(interval); self } /// Specifies that this peer does not require keepalive packets. + #[must_use] pub fn unset_persistent_keepalive(self) -> Self { self.set_persistent_keepalive_interval(0) } @@ -133,6 +138,7 @@ impl PeerConfigBuilder { /// /// See [`AllowedIp`](AllowedIp) for details. This method can be called /// more than once, and all IP addresses will be added to the configuration. + #[must_use] pub fn add_allowed_ip(mut self, address: IpAddr, cidr: u8) -> Self { self.allowed_ips.push(AllowedIp { address, cidr }); self @@ -142,6 +148,7 @@ impl PeerConfigBuilder { /// /// See [`AllowedIp`](AllowedIp) for details. This method can be called /// more than once, and all IP addresses will be added to the configuration. + #[must_use] pub fn add_allowed_ips(mut self, ips: &[AllowedIp]) -> Self { self.allowed_ips.extend_from_slice(ips); self @@ -151,6 +158,7 @@ impl PeerConfigBuilder { /// /// This is a convenience method for cases when you want to connect to a server /// that all traffic should be routed through. + #[must_use] pub fn allow_all_ips(self) -> Self { self.add_allowed_ip(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) .add_allowed_ip(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0) @@ -158,12 +166,14 @@ impl PeerConfigBuilder { /// Specifies that the allowed IP addresses in this configuration should replace /// the existing configuration of the interface, not be appended to it. + #[must_use] pub fn replace_allowed_ips(mut self) -> Self { self.replace_allowed_ips = true; self } /// Mark peer for removal from interface. + #[must_use] pub fn remove(mut self) -> Self { self.remove_me = true; self diff --git a/wireguard-control/src/device.rs b/wireguard-control/src/device.rs index acc42a3..b07e37b 100644 --- a/wireguard-control/src/device.rs +++ b/wireguard-control/src/device.rs @@ -150,16 +150,6 @@ impl FromStr for InterfaceName { } impl InterfaceName { - #[cfg(target_os = "linux")] - /// Creates a new [InterfaceName](Self). - /// - /// ## Safety - /// - /// The caller must ensure that `name` is a valid C string terminated by a NUL. - pub(crate) unsafe fn from_wg(name: RawInterfaceName) -> Self { - Self(name) - } - /// Returns a human-readable form of the device name. /// /// Only use this when the interface name was constructed from a Rust string. @@ -173,12 +163,6 @@ impl InterfaceName { pub fn as_ptr(&self) -> *const c_char { self.0.as_ptr() } - - #[cfg(target_os = "linux")] - /// Consumes this interface name, returning its raw byte buffer. - pub(crate) fn into_inner(self) -> RawInterfaceName { - self.0 - } } impl fmt::Debug for InterfaceName { @@ -307,6 +291,7 @@ pub struct DeviceUpdate { impl DeviceUpdate { /// Creates a new `DeviceConfigBuilder` that does nothing when applied. + #[must_use] pub fn new() -> Self { DeviceUpdate { public_key: None, @@ -323,40 +308,47 @@ impl DeviceUpdate { /// This is a convenience method that simply wraps /// [`set_public_key`](DeviceConfigBuilder::set_public_key) /// and [`set_private_key`](DeviceConfigBuilder::set_private_key). + #[must_use] pub fn set_keypair(self, keypair: KeyPair) -> Self { self.set_public_key(keypair.public) .set_private_key(keypair.private) } /// Specifies a new public key to be applied to the interface. + #[must_use] pub fn set_public_key(mut self, key: Key) -> Self { self.public_key = Some(key); self } /// Specifies that the public key for this interface should be unset. + #[must_use] pub fn unset_public_key(self) -> Self { self.set_public_key(Key::zero()) } /// Sets a new private key to be applied to the interface. + #[must_use] pub fn set_private_key(mut self, key: Key) -> Self { self.private_key = Some(key); self } /// Specifies that the private key for this interface should be unset. + #[must_use] pub fn unset_private_key(self) -> Self { self.set_private_key(Key::zero()) } /// Specifies the fwmark value that should be applied to packets coming from the interface. + #[must_use] pub fn set_fwmark(mut self, fwmark: u32) -> Self { self.fwmark = Some(fwmark); self } /// Specifies that fwmark should not be set on packets from the interface. + #[must_use] pub fn unset_fwmark(self) -> Self { self.set_fwmark(0) } @@ -364,6 +356,7 @@ impl DeviceUpdate { /// Specifies the port to listen for incoming packets on. /// /// This is useful for a server configuration that listens on a fixed endpoint. + #[must_use] pub fn set_listen_port(mut self, port: u16) -> Self { self.listen_port = Some(port); self @@ -372,6 +365,7 @@ impl DeviceUpdate { /// Specifies that a random port should be used for incoming packets. /// /// This is probably what you want in client configurations. + #[must_use] pub fn randomize_listen_port(self) -> Self { self.set_listen_port(0) } @@ -381,6 +375,7 @@ impl DeviceUpdate { /// See [`PeerConfigBuilder`](PeerConfigBuilder) for details on building /// peer configurations. This method can be called more than once, and all /// peers will be added to the configuration. + #[must_use] pub fn add_peer(mut self, peer: PeerConfigBuilder) -> Self { self.peers.push(peer); self @@ -391,6 +386,7 @@ impl DeviceUpdate { /// This is simply a convenience method to make adding peers more fluent. /// This method can be called more than once, and all peers will be added /// to the configuration. + #[must_use] pub fn add_peer_with( self, pubkey: &Key, @@ -400,6 +396,7 @@ impl DeviceUpdate { } /// Specifies multiple peer configurations to be added to the interface. + #[must_use] pub fn add_peers(mut self, peers: &[PeerConfigBuilder]) -> Self { self.peers.extend_from_slice(peers); self @@ -407,12 +404,14 @@ impl DeviceUpdate { /// Specifies that the peer configurations in this `DeviceConfigBuilder` should /// replace the existing configurations on the interface, not modify or append to them. + #[must_use] pub fn replace_peers(mut self) -> Self { self.replace_peers = true; self } /// Specifies that the peer with this public key should be removed from the interface. + #[must_use] pub fn remove_peer_by_key(self, public_key: &Key) -> Self { let mut peer = PeerConfigBuilder::new(public_key); peer.remove_me = true; diff --git a/wireguard-control/src/key.rs b/wireguard-control/src/key.rs index d8c707d..55ee368 100644 --- a/wireguard-control/src/key.rs +++ b/wireguard-control/src/key.rs @@ -1,4 +1,3 @@ -use crate::backends; use std::{ffi::NulError, fmt}; /// Represents an error in base64 key parsing. @@ -27,11 +26,122 @@ impl From for InvalidKey { /// /// This means that you need to be careful when working with /// `Key`s, especially ones created from external data. -#[cfg(not(target_os = "linux"))] -pub use backends::userspace::Key; +#[derive(PartialEq, Eq, Clone)] +pub struct Key(pub [u8; 32]); -#[cfg(target_os = "linux")] -pub use backends::kernel::Key; +impl Key { + /// Generates and returns a new private key. + pub fn generate_private() -> Self { + use rand_core::{OsRng, RngCore}; + + let mut bytes = [0u8; 32]; + OsRng.fill_bytes(&mut bytes); + + // Apply key clamping. + bytes[0] &= 248; + bytes[31] &= 127; + bytes[31] |= 64; + Self(bytes) + } + + /// Generates and returns a new preshared key. + #[must_use] + pub fn generate_preshared() -> Self { + use rand_core::{OsRng, RngCore}; + + let mut key = [0u8; 32]; + OsRng.fill_bytes(&mut key); + Self(key) + } + + /// Generates a public key for this private key. + #[must_use] + pub fn generate_public(&self) -> Self { + use curve25519_dalek::scalar::Scalar; + + use curve25519_dalek::constants::ED25519_BASEPOINT_TABLE; + + // https://github.com/dalek-cryptography/x25519-dalek/blob/1c39ff92e0dfc0b24aa02d694f26f3b9539322a5/src/x25519.rs#L150 + let point = (&ED25519_BASEPOINT_TABLE * &Scalar::from_bits(self.0)).to_montgomery(); + + Self(point.to_bytes()) + } + + /// Generates an all-zero key. + #[must_use] + pub fn zero() -> Self { + Self([0u8; 32]) + } + + pub fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Converts the key to a standardized base64 representation, as used by the `wg` utility and `wg-quick`. + pub fn to_base64(&self) -> String { + base64::encode(&self.0) + } + + /// Converts a base64 representation of the key to the raw bytes. + /// + /// This can fail, as not all text input is valid base64 - in this case + /// `Err(InvalidKey)` is returned. + pub fn from_base64(key: &str) -> Result { + let mut key_bytes = [0u8; 32]; + let decoded_bytes = base64::decode(key).map_err(|_| InvalidKey)?; + + if decoded_bytes.len() != 32 { + return Err(InvalidKey); + } + + key_bytes.copy_from_slice(&decoded_bytes[..]); + Ok(Self(key_bytes)) + } + + pub fn from_hex(hex_str: &str) -> Result { + let mut sized_bytes = [0u8; 32]; + hex::decode_to_slice(hex_str, &mut sized_bytes).map_err(|_| InvalidKey)?; + Ok(Self(sized_bytes)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_pubkey_generation() { + let privkey = "SGb+ojrRNDuMePufwtIYhXzA//k6wF3R21tEBgKlzlM="; + let pubkey = "DD5yKRfzExcV5+kDnTroDgCU15latdMjiQ59j1hEuk8="; + + let private = Key::from_base64(privkey).unwrap(); + let public = Key::generate_public(&private); + + assert_eq!(public.to_base64(), pubkey); + } + + #[test] + fn test_rng_sanity_private() { + let first = Key::generate_private(); + assert!(first.as_bytes() != [0u8; 32]); + for _ in 0..100_000 { + let key = Key::generate_private(); + assert!(first != key); + assert!(key.as_bytes() != [0u8; 32]); + } + } + + #[test] + fn test_rng_sanity_preshared() { + let first = Key::generate_preshared(); + assert!(first.as_bytes() != [0u8; 32]); + for _ in 0..100_000 { + let key = Key::generate_preshared(); + assert!(first != key); + assert!(key.as_bytes() != [0u8; 32]); + } + } +} /// Represents a pair of private and public keys. ///