From 8903604caa041db3e97733a74a83ce2e63563933 Mon Sep 17 00:00:00 2001 From: Jake McGinty Date: Wed, 1 Sep 2021 18:58:46 +0900 Subject: [PATCH] NAT traversal: ICE-esque candidate selection (#134) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change adds the ability for peers to report additional candidate endpoints for other peers to attempt connections with outside of the endpoint reported by the coordinating server. While not a complete solution to the full spectrum of NAT traversal issues (TURN-esque proxying is still notably missing), it allows peers within the same NAT to connect to each other via their LAN addresses, which is a win nonetheless. In the future, more advanced candidate discovery could be used to punch through additional types of NAT cone types as well. Co-authored-by: Matěj Laitl --- Cargo.lock | 27 +++++++- client/src/data_store.rs | 11 ++-- client/src/main.rs | 121 ++++++++++++++++-------------------- client/src/nat.rs | 125 +++++++++++++++++++++++++++++++++++++ client/src/util.rs | 61 +++++++++++++++--- server/src/api/user.rs | 74 +++++++++++++++++++--- server/src/db/mod.rs | 10 ++- server/src/db/peer.rs | 73 ++++++++++++++-------- server/src/initialize.rs | 2 + server/src/main.rs | 2 + server/src/test.rs | 1 + shared/Cargo.toml | 3 + shared/src/netlink.rs | 130 +++++++++++++++++++++++++++++++++++---- shared/src/prompts.rs | 1 + shared/src/types.rs | 33 +++++----- shared/src/wg.rs | 99 ++++++++++++++++++++++++++++- 16 files changed, 623 insertions(+), 150 deletions(-) create mode 100644 client/src/nat.rs diff --git a/Cargo.lock b/Cargo.lock index 71b8636..87100ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,9 +525,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.98" +version = "0.2.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" +checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" [[package]] name = "libsqlite3-sys" @@ -569,6 +569,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" +[[package]] +name = "memoffset" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9" +dependencies = [ + "autocfg", +] + [[package]] name = "mio" version = "0.7.13" @@ -639,6 +648,19 @@ dependencies = [ "log", ] +[[package]] +name = "nix" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7555d6c7164cc913be1ce7f95cbecdabda61eb2ccd89008524af306fb7f5031" +dependencies = [ + "bitflags", + "cc", + "cfg-if", + "libc", + "memoffset", +] + [[package]] name = "nom" version = "6.1.2" @@ -999,6 +1021,7 @@ dependencies = [ "netlink-packet-core", "netlink-packet-route", "netlink-sys", + "nix", "publicip", "regex", "serde", diff --git a/client/src/data_store.rs b/client/src/data_store.rs index df582a2..1736156 100644 --- a/client/src/data_store.rs +++ b/client/src/data_store.rs @@ -80,7 +80,7 @@ impl DataStore { /// /// Note, however, that this does not prevent a compromised server from adding a new /// peer under its control, of course. - pub fn update_peers(&mut self, current_peers: Vec) -> Result<(), Error> { + pub fn update_peers(&mut self, current_peers: &[Peer]) -> Result<(), Error> { let peers = match &mut self.contents { Contents::V1 { ref mut peers, .. } => peers, }; @@ -149,6 +149,7 @@ mod tests { is_redeemed: true, persistent_keepalive_interval: None, invite_expires: None, + candidates: vec![], } }]; static ref BASE_CIDRS: Vec = vec![Cidr { @@ -167,7 +168,7 @@ mod tests { assert_eq!(0, store.peers().len()); assert_eq!(0, store.cidrs().len()); - store.update_peers(BASE_PEERS.to_owned()).unwrap(); + store.update_peers(&BASE_PEERS).unwrap(); store.set_cidrs(BASE_CIDRS.to_owned()); store.write().unwrap(); } @@ -189,13 +190,13 @@ mod tests { DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap(); // Should work, since peer is unmodified. - store.update_peers(BASE_PEERS.clone()).unwrap(); + store.update_peers(&BASE_PEERS).unwrap(); let mut modified = BASE_PEERS.clone(); modified[0].contents.public_key = "foo".to_string(); // Should NOT work, since peer is unmodified. - assert!(store.update_peers(modified).is_err()); + assert!(store.update_peers(&modified).is_err()); } #[test] @@ -206,7 +207,7 @@ mod tests { DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap(); // Should work, since peer is unmodified. - store.update_peers(vec![]).unwrap(); + store.update_peers(&[]).unwrap(); let new_peers = BASE_PEERS .iter() .cloned() diff --git a/client/src/main.rs b/client/src/main.rs index e411883..bc67f86 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -4,13 +4,17 @@ use dialoguer::{Confirm, Input}; use hostsfile::HostsBuilder; use indoc::eprintdoc; use shared::{ - interface_config::InterfaceConfig, prompts, AddAssociationOpts, AddCidrOpts, AddPeerOpts, - Association, AssociationContents, Cidr, CidrTree, DeleteCidrOpts, EndpointContents, - InstallOpts, Interface, IoErrorContext, NetworkOpt, Peer, PeerDiff, RedeemContents, - RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR, REDEEM_TRANSITION_WAIT, + interface_config::InterfaceConfig, + prompts, + wg::{DeviceExt, PeerInfoExt}, + AddAssociationOpts, AddCidrOpts, AddPeerOpts, Association, AssociationContents, Cidr, CidrTree, + DeleteCidrOpts, Endpoint, EndpointContents, InstallOpts, Interface, IoErrorContext, NetworkOpt, + Peer, RedeemContents, RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR, + REDEEM_TRANSITION_WAIT, }; use std::{ fmt, io, + net::SocketAddr, path::{Path, PathBuf}, thread, time::Duration, @@ -19,9 +23,11 @@ use structopt::{clap::AppSettings, StructOpt}; use wgctrl::{Device, DeviceUpdate, InterfaceName, PeerConfigBuilder, PeerInfo}; mod data_store; +mod nat; mod util; use data_store::DataStore; +use nat::NatTraverse; use shared::{wg, Error}; use util::{human_duration, human_size, Api}; @@ -484,70 +490,13 @@ fn fetch( let mut store = DataStore::open_or_create(interface)?; let State { peers, cidrs } = Api::new(&config.server).http("GET", "/user/state")?; - let device_info = Device::get(interface, network.backend).with_str(interface.as_str_lossy())?; - let interface_public_key = device_info - .public_key - .as_ref() - .map(|k| k.to_base64()) - .unwrap_or_default(); - let existing_peers = &device_info.peers; - - // Match existing peers (by pubkey) to new peer information from the server. - let modifications = peers.iter().filter_map(|peer| { - if peer.is_disabled || peer.public_key == interface_public_key { - None - } else { - let existing_peer = existing_peers - .iter() - .find(|p| p.config.public_key.to_base64() == peer.public_key); - PeerDiff::new(existing_peer, Some(peer)).unwrap() - } - }); - - // Remove any peers on the interface that aren't in the server's peer list any more. - let removals = existing_peers.iter().filter_map(|existing| { - let public_key = existing.config.public_key.to_base64(); - if peers.iter().any(|p| p.public_key == public_key) { - None - } else { - PeerDiff::new(Some(existing), None).unwrap() - } - }); + let device = Device::get(interface, network.backend)?; + let modifications = device.diff(&peers); let updates = modifications - .chain(removals) - .inspect(|diff| { - let public_key = diff.public_key().to_base64(); - - let text = match (diff.old, diff.new) { - (None, Some(_)) => "added".green(), - (Some(_), Some(_)) => "modified".yellow(), - (Some(_), None) => "removed".red(), - _ => unreachable!("PeerDiff can't be None -> None"), - }; - - // Grab the peer name from either the new data, or the historical data (if the peer is removed). - let peer_hostname = match diff.new { - Some(peer) => Some(peer.name.clone()), - _ => store - .peers() - .iter() - .find(|p| p.public_key == public_key) - .map(|p| p.name.clone()), - }; - let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]"); - - log::info!( - " peer {} ({}...) was {}.", - peer_name.yellow(), - &public_key[..10].dimmed(), - text - ); - - for change in diff.changes() { - log::debug!(" {}", change); - } - }) + .iter() + .inspect(|diff| util::print_peer_diff(&store, diff)) + .cloned() .map(PeerConfigBuilder::from) .collect::>(); @@ -566,10 +515,44 @@ fn fetch( } else { log::info!("{}", "peers are already up to date.".green()); } + store.set_cidrs(cidrs); - store.update_peers(peers)?; + store.update_peers(&peers)?; store.write().with_str(interface.to_string())?; + let candidates = wg::get_local_addrs()? + .into_iter() + .map(|addr| SocketAddr::from((addr, device.listen_port.unwrap_or(51820))).into()) + .take(10) + .collect::>(); + log::info!( + "reporting {} interface address{} as NAT traversal candidates...", + candidates.len(), + if candidates.len() == 1 { "" } else { "es" } + ); + log::debug!("candidates: {:?}", candidates); + match Api::new(&config.server).http_form::<_, ()>("PUT", "/user/candidates", &candidates) { + Err(ureq::Error::Status(404, _)) => { + log::warn!("your network is using an old version of innernet-server that doesn't support NAT traversal candidate reporting.") + }, + Err(e) => return Err(e.into()), + _ => {}, + } + + log::debug!("viable ICE candidates: {:?}", candidates); + + let mut nat_traverse = NatTraverse::new(interface, network.backend, &modifications)?; + loop { + if nat_traverse.is_finished() { + break; + } + log::info!( + "Attempting to establish connection with {} remaining unconnected peers...", + nat_traverse.remaining() + ); + nat_traverse.step()?; + } + Ok(()) } @@ -992,7 +975,9 @@ fn print_peer(peer: &PeerState, short: bool, level: usize) { let pad = level * 2; let PeerState { peer, info } = peer; if short { - let connected = PeerDiff::peer_recently_connected(info); + let connected = info + .map(|info| !info.is_recently_connected()) + .unwrap_or_default(); println_pad!( pad, diff --git a/client/src/nat.rs b/client/src/nat.rs new file mode 100644 index 0000000..b7d22d5 --- /dev/null +++ b/client/src/nat.rs @@ -0,0 +1,125 @@ +//! ICE-like NAT traversal logic. +//! +//! Doesn't follow the specific ICE protocol, but takes great inspiration from RFC 8445 +//! and applies it to a protocol more specific to innernet. + +use std::time::{Duration, Instant}; + +use anyhow::Error; +use shared::{ + wg::{DeviceExt, PeerInfoExt}, + Endpoint, Peer, PeerDiff, +}; +use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder}; + +const STEP_INTERVAL: Duration = Duration::from_secs(5); + +pub struct NatTraverse<'a> { + interface: &'a InterfaceName, + backend: Backend, + remaining: Vec, +} + +impl<'a> NatTraverse<'a> { + pub fn new(interface: &'a InterfaceName, backend: Backend, diffs: &[PeerDiff]) -> Result { + let mut remaining: Vec<_> = diffs.iter().filter_map(|diff| diff.new).cloned().collect(); + + for peer in &mut remaining { + // Limit reported alternative candidates to 10. + peer.candidates.truncate(10); + + // remove server-reported endpoint from elsewhere in the list if it existed. + let endpoint = peer.endpoint.clone(); + peer.candidates + .retain(|addr| Some(addr) != endpoint.as_ref()); + } + let mut nat_traverse = Self { + interface, + backend, + remaining, + }; + nat_traverse.refresh_remaining()?; + Ok(nat_traverse) + } + + pub fn is_finished(&self) -> bool { + self.remaining.is_empty() + } + + pub fn remaining(&self) -> usize { + self.remaining.len() + } + + /// Refreshes the current state of candidate traversal attempts, returning + /// the peers that have been exhausted of all options (not included are + /// peers that have successfully connected, or peers removed from the interface). + fn refresh_remaining(&mut self) -> Result, Error> { + let device = Device::get(self.interface, self.backend)?; + // Remove connected and missing peers + self.remaining.retain(|peer| { + if let Some(peer_info) = device.get_peer(&peer.public_key) { + let recently_connected = peer_info.is_recently_connected(); + if recently_connected { + log::debug!( + "peer {} removed from NAT traverser (connected!).", + peer.name + ); + } + !recently_connected + } else { + log::debug!( + "peer {} removed from NAT traverser (no longer on interface).", + peer.name + ); + false + } + }); + let (exhausted, remaining): (Vec<_>, Vec<_>) = self + .remaining + .drain(..) + .partition(|peer| peer.candidates.is_empty()); + self.remaining = remaining; + Ok(exhausted) + } + + pub fn step(&mut self) -> Result<(), Error> { + let exhuasted = self.refresh_remaining()?; + + // Reset peer endpoints that had no viable candidates back to the server-reported one, if it exists. + let reset_updates = exhuasted + .into_iter() + .filter_map(|peer| set_endpoint(&peer.public_key, peer.endpoint.as_ref())); + + // Set all peers' endpoints to their next available candidate. + let candidate_updates = self.remaining.iter_mut().filter_map(|peer| { + let endpoint = peer.candidates.pop(); + set_endpoint(&peer.public_key, endpoint.as_ref()) + }); + + let updates: Vec<_> = reset_updates.chain(candidate_updates).collect(); + + DeviceUpdate::new() + .add_peers(&updates) + .apply(self.interface, self.backend)?; + + let start = Instant::now(); + while start.elapsed() < STEP_INTERVAL { + self.refresh_remaining()?; + if self.is_finished() { + log::debug!("NAT traverser is finished!"); + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + Ok(()) + } +} + +/// Return a PeerConfigBuilder if an endpoint exists and resolves successfully. +fn set_endpoint(public_key: &str, endpoint: Option<&Endpoint>) -> Option { + endpoint + .and_then(|endpoint| endpoint.resolve().ok()) + .map(|addr| { + PeerConfigBuilder::new(&Key::from_base64(public_key).unwrap()).set_endpoint(addr) + }) +} diff --git a/client/src/util.rs b/client/src/util.rs index 310fbe8..d853685 100644 --- a/client/src/util.rs +++ b/client/src/util.rs @@ -1,9 +1,9 @@ -use crate::{ClientError, Error}; +use crate::data_store::DataStore; use colored::*; use indoc::eprintdoc; use log::{Level, LevelFilter}; use serde::{de::DeserializeOwned, Serialize}; -use shared::{interface_config::ServerInfo, INNERNET_PUBKEY_HEADER}; +use shared::{interface_config::ServerInfo, PeerDiff, INNERNET_PUBKEY_HEADER}; use std::{io, time::Duration}; use ureq::{Agent, AgentBuilder}; @@ -137,6 +137,39 @@ pub fn permissions_helptext(e: &io::Error) { } } +pub fn print_peer_diff(store: &DataStore, diff: &PeerDiff) { + let public_key = diff.public_key().to_base64(); + + let text = match (diff.old, diff.new) { + (None, Some(_)) => "added".green(), + (Some(_), Some(_)) => "modified".yellow(), + (Some(_), None) => "removed".red(), + _ => unreachable!("PeerDiff can't be None -> None"), + }; + + // Grab the peer name from either the new data, or the historical data (if the peer is removed). + let peer_hostname = match diff.new { + Some(peer) => Some(peer.name.clone()), + None => store + .peers() + .iter() + .find(|p| p.public_key == public_key) + .map(|p| p.name.clone()), + }; + let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]"); + + log::info!( + " peer {} ({}...) was {}.", + peer_name.yellow(), + &public_key[..10].dimmed(), + text + ); + + for change in diff.changes() { + log::debug!(" {}", change); + } +} + pub struct Api<'a> { agent: Agent, server: &'a ServerInfo, @@ -151,7 +184,7 @@ impl<'a> Api<'a> { Self { agent, server } } - pub fn http(&self, verb: &str, endpoint: &str) -> Result { + pub fn http(&self, verb: &str, endpoint: &str) -> Result { self.request::<(), _>(verb, endpoint, None) } @@ -160,7 +193,7 @@ impl<'a> Api<'a> { verb: &str, endpoint: &str, form: S, - ) -> Result { + ) -> Result { self.request(verb, endpoint, Some(form)) } @@ -169,7 +202,7 @@ impl<'a> Api<'a> { verb: &str, endpoint: &str, form: Option, - ) -> Result { + ) -> Result { let request = self .agent .request( @@ -179,7 +212,12 @@ impl<'a> Api<'a> { .set(INNERNET_PUBKEY_HEADER, &self.server.public_key); let response = if let Some(form) = form { - request.send_json(serde_json::to_value(form)?)? + request.send_json(serde_json::to_value(form).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to serialize JSON request: {}", e), + ) + })?)? } else { request.call()? }; @@ -190,10 +228,13 @@ impl<'a> Api<'a> { response = "null".into(); } Ok(serde_json::from_str(&response).map_err(|e| { - ClientError(format!( - "failed to deserialize JSON response from the server: {}, response={}", - e, &response - )) + io::Error::new( + io::ErrorKind::InvalidData, + format!( + "failed to deserialize JSON response from the server: {}, response={}", + e, &response + ), + ) })?) } } diff --git a/server/src/api/user.rs b/server/src/api/user.rs index 40c189c..ffedd76 100644 --- a/server/src/api/user.rs +++ b/server/src/api/user.rs @@ -36,11 +36,20 @@ pub async fn routes( let form = form_body(req).await?; handlers::endpoint(form, session).await }, + (&Method::PUT, Some("candidates")) => { + if !session.user_capable() { + return Err(ServerError::Unauthorized); + } + let form = form_body(req).await?; + handlers::candidates(form, session).await + }, _ => Err(ServerError::NotFound), } } mod handlers { + use shared::Endpoint; + use super::*; /// Get the current state of the network, in the eyes of the current peer. @@ -115,14 +124,29 @@ mod handlers { status_response(StatusCode::NO_CONTENT) } - /// Redeems an invitation. An invitation includes a WireGuard keypair generated by either the server - /// or a peer with admin rights. - /// - /// Redemption is the process of an invitee generating their own keypair and exchanging their temporary - /// key with their permanent one. - /// - /// Until this API endpoint is called, the invited peer will not show up to other peers, and once - /// it is called and succeeds, it cannot be called again. + /// Report any other endpoint candidates that can be tried by peers to connect. + /// Currently limited to 10 candidates max. + pub async fn candidates( + contents: Vec, + session: Session, + ) -> Result, ServerError> { + if contents.len() > 10 { + return status_response(StatusCode::PAYLOAD_TOO_LARGE); + } + let conn = session.context.db.lock(); + let mut selected_peer = DatabasePeer::get(&conn, session.peer.id)?; + selected_peer.update( + &conn, + PeerContents { + candidates: contents, + ..selected_peer.contents.clone() + }, + )?; + + status_response(StatusCode::NO_CONTENT) + } + + /// Force a specific endpoint to be reported by the server. pub async fn endpoint( contents: EndpointContents, session: Session, @@ -148,7 +172,7 @@ mod tests { use super::*; use crate::{db::DatabaseAssociation, test}; use bytes::Buf; - use shared::{AssociationContents, CidrContents, EndpointContents, Error}; + use shared::{AssociationContents, CidrContents, Endpoint, EndpointContents, Error}; #[tokio::test] async fn test_get_state_from_developer1() -> Result<(), Error> { @@ -406,4 +430,36 @@ mod tests { assert_eq!(res.status(), StatusCode::UNAUTHORIZED); Ok(()) } + + #[tokio::test] + async fn test_candidates() -> Result<(), Error> { + let server = test::Server::new()?; + + let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?; + assert_eq!(peer.candidates, vec![]); + + let candidates = vec!["1.1.1.1:51820".parse::().unwrap()]; + assert_eq!( + server + .form_request( + test::DEVELOPER1_PEER_IP, + "PUT", + "/v1/user/candidates", + &candidates + ) + .await + .status(), + StatusCode::NO_CONTENT + ); + + let res = server + .request(test::DEVELOPER1_PEER_IP, "GET", "/v1/user/state") + .await; + + assert_eq!(res.status(), StatusCode::OK); + + let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?; + assert_eq!(peer.candidates, candidates); + Ok(()) + } } diff --git a/server/src/db/mod.rs b/server/src/db/mod.rs index 60a7912..21ae8c9 100644 --- a/server/src/db/mod.rs +++ b/server/src/db/mod.rs @@ -8,11 +8,13 @@ pub use peer::DatabasePeer; use rusqlite::params; const INVITE_EXPIRATION_VERSION: usize = 1; +const ENDPOINT_CANDIDATES_VERSION: usize = 2; -pub const CURRENT_VERSION: usize = INVITE_EXPIRATION_VERSION; +pub const CURRENT_VERSION: usize = ENDPOINT_CANDIDATES_VERSION; pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { let old_version: usize = conn.pragma_query_value(None, "user_version", |r| r.get(0))?; + log::debug!("user_version: {}", old_version); if old_version < INVITE_EXPIRATION_VERSION { conn.execute( @@ -21,8 +23,12 @@ pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> )?; } - conn.pragma_update(None, "user_version", &CURRENT_VERSION)?; + if old_version < ENDPOINT_CANDIDATES_VERSION { + conn.execute("ALTER TABLE peers ADD COLUMN candidates TEXT", params![])?; + } + if old_version != CURRENT_VERSION { + conn.pragma_update(None, "user_version", &CURRENT_VERSION)?; log::info!( "migrated db version from {} to {}", old_version, diff --git a/server/src/db/peer.rs b/server/src/db/peer.rs index 5f729a0..5bf7973 100644 --- a/server/src/db/peer.rs +++ b/server/src/db/peer.rs @@ -2,8 +2,8 @@ use super::DatabaseCidr; use crate::ServerError; use lazy_static::lazy_static; use regex::Regex; -use rusqlite::{params, Connection}; -use shared::{Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS}; +use rusqlite::{params, types::Type, Connection}; +use shared::{Endpoint, Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS}; use std::{ net::IpAddr, ops::{Deref, DerefMut}, @@ -22,12 +22,27 @@ pub static CREATE_TABLE_SQL: &str = "CREATE TABLE peers ( is_disabled INTEGER DEFAULT 0 NOT NULL, /* Is the peer disabled? (peers cannot be deleted) */ is_redeemed INTEGER DEFAULT 0 NOT NULL, /* Has the peer redeemed their invite yet? */ invite_expires INTEGER, /* The UNIX time that an invited peer can no longer redeem. */ + candidates TEXT, /* A list of additional endpoints that peers can use to connect. */ FOREIGN KEY (cidr_id) REFERENCES cidrs (id) ON UPDATE RESTRICT ON DELETE RESTRICT )"; +pub static COLUMNS: &[&str] = &[ + "id", + "name", + "ip", + "cidr_id", + "public_key", + "endpoint", + "is_admin", + "is_disabled", + "is_redeemed", + "invite_expires", + "candidates", +]; + lazy_static! { /// Regex to match the requirements of hostname(7), needed to have peers also be reachable hostnames. /// Note that the full length also must be maximum 63 characters, which this regex does not check. @@ -71,6 +86,7 @@ impl DatabasePeer { is_disabled, is_redeemed, invite_expires, + candidates, .. } = &contents; log::info!("creating peer {:?}", contents); @@ -96,8 +112,13 @@ impl DatabasePeer { .flatten() .map(|t| t.as_secs()); + let candidates = serde_json::to_string(candidates)?; + conn.execute( - "INSERT INTO peers (name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + &format!( + "INSERT INTO peers ({}) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", + COLUMNS[1..].join(", ") + ), params![ &**name, ip.to_string(), @@ -108,6 +129,7 @@ impl DatabasePeer { is_disabled, is_redeemed, invite_expires, + candidates, ], )?; let id = conn.last_insert_rowid(); @@ -135,17 +157,21 @@ impl DatabasePeer { endpoint: contents.endpoint, is_admin: contents.is_admin, is_disabled: contents.is_disabled, + candidates: contents.candidates, ..self.contents.clone() }; + let new_candidates = serde_json::to_string(&new_contents.candidates)?; conn.execute( "UPDATE peers SET - name = ?1, - endpoint = ?2, - is_admin = ?3, - is_disabled = ?4 - WHERE id = ?5", + name = ?2, + endpoint = ?3, + is_admin = ?4, + is_disabled = ?5, + candidates = ?6 + WHERE id = ?1", params![ + self.id, &*new_contents.name, new_contents .endpoint @@ -153,7 +179,7 @@ impl DatabasePeer { .map(|endpoint| endpoint.to_string()), new_contents.is_admin, new_contents.is_disabled, - self.id, + new_candidates, ], )?; @@ -198,11 +224,11 @@ impl DatabasePeer { let name = row .get::<_, String>(1)? .parse() - .map_err(|_| rusqlite::Error::ExecuteReturnedResults)?; + .map_err(|_| rusqlite::Error::InvalidColumnType(1, "hostname".into(), Type::Text))?; let ip: IpAddr = row .get::<_, String>(2)? .parse() - .map_err(|_| rusqlite::Error::ExecuteReturnedResults)?; + .map_err(|_| rusqlite::Error::InvalidColumnType(2, "ip".into(), Type::Text))?; let cidr_id = row.get(3)?; let public_key = row.get(4)?; let endpoint = row @@ -214,6 +240,10 @@ impl DatabasePeer { let invite_expires = row .get::<_, Option>(9)? .map(|unixtime| SystemTime::UNIX_EPOCH + Duration::from_secs(unixtime)); + let candidates_str: String = row.get(10)?; + let candidates: Vec = serde_json::from_str(&candidates_str).map_err(|_| { + rusqlite::Error::InvalidColumnType(10, "candidates (json)".into(), Type::Text) + })?; let persistent_keepalive_interval = Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS); @@ -230,6 +260,7 @@ impl DatabasePeer { is_disabled, is_redeemed, invite_expires, + candidates, }, } .into()) @@ -237,10 +268,7 @@ impl DatabasePeer { pub fn get(conn: &Connection, id: i64) -> Result { let result = conn.query_row( - "SELECT - id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires - FROM peers - WHERE id = ?1", + &format!("SELECT {} FROM peers WHERE id = ?1", COLUMNS.join(", ")), params![id], Self::from_row, )?; @@ -250,10 +278,7 @@ impl DatabasePeer { pub fn get_from_ip(conn: &Connection, ip: IpAddr) -> Result { let result = conn.query_row( - "SELECT - id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires - FROM peers - WHERE ip = ?1", + &format!("SELECT {} FROM peers WHERE ip = ?1", COLUMNS.join(", ")), params![ip.to_string()], Self::from_row, )?; @@ -271,7 +296,7 @@ impl DatabasePeer { // // NOTE that a forced association is created with the special "infra" CIDR with id 2 (1 being the root). let mut stmt = conn.prepare_cached( - "WITH + &format!("WITH parent_of(id, parent) AS ( SELECT id, parent FROM cidrs WHERE id = ?1 UNION ALL @@ -289,10 +314,12 @@ impl DatabasePeer { UNION SELECT id FROM cidrs, associated_subcidrs WHERE cidrs.parent=associated_subcidrs.cidr_id ) - SELECT DISTINCT peers.id, peers.name, peers.ip, peers.cidr_id, peers.public_key, peers.endpoint, peers.is_admin, peers.is_disabled, peers.is_redeemed, peers.invite_expires + SELECT DISTINCT {} FROM peers JOIN associated_subcidrs ON peers.cidr_id=associated_subcidrs.cidr_id WHERE peers.is_disabled = 0 AND peers.is_redeemed = 1;", + COLUMNS.iter().map(|col| format!("peers.{}", col)).collect::>().join(", ") + ), )?; let peers = stmt .query_map(params![self.cidr_id], Self::from_row)? @@ -301,9 +328,7 @@ impl DatabasePeer { } pub fn list(conn: &Connection) -> Result, ServerError> { - let mut stmt = conn.prepare_cached( - "SELECT id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires FROM peers", - )?; + let mut stmt = conn.prepare_cached(&format!("SELECT {} FROM peers", COLUMNS.join(", ")))?; let peer_iter = stmt.query_map(params![], Self::from_row)?; Ok(peer_iter.collect::>()?) diff --git a/server/src/initialize.rs b/server/src/initialize.rs index e440071..3695475 100644 --- a/server/src/initialize.rs +++ b/server/src/initialize.rs @@ -17,6 +17,7 @@ fn create_database>( conn.execute(db::association::CREATE_TABLE_SQL, params![])?; conn.execute(db::cidr::CREATE_TABLE_SQL, params![])?; conn.pragma_update(None, "user_version", &db::CURRENT_VERSION)?; + log::debug!("set database version to db::CURRENT_VERSION"); Ok(conn) } @@ -89,6 +90,7 @@ fn populate_database(conn: &Connection, db_init_data: DbInitData) -> Result<(), is_redeemed: true, persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS), invite_expires: None, + candidates: vec![], }, ) .map_err(|_| anyhow!("failed to create innernet peer."))?; diff --git a/server/src/main.rs b/server/src/main.rs index 5157ffe..2031f8f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -476,9 +476,11 @@ async fn serve( network: NetworkOpt, ) -> Result<(), Error> { let config = ConfigFile::from_file(conf.config_path(&interface))?; + log::debug!("opening database connection..."); let conn = open_database_connection(&interface, conf)?; let peers = DatabasePeer::list(&conn)?; + log::debug!("peers listed..."); let peer_configs = peers .iter() .map(|peer| peer.deref().into()) diff --git a/server/src/test.rs b/server/src/test.rs index 0f1b3e9..7d90e03 100644 --- a/server/src/test.rs +++ b/server/src/test.rs @@ -231,6 +231,7 @@ pub fn peer_contents( is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }) } diff --git a/shared/Cargo.toml b/shared/Cargo.toml index 84048fa..169d25f 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -29,3 +29,6 @@ netlink-sys = "0.7" netlink-packet-core = "0.2" netlink-packet-route = "0.7" wgctrl-sys = { path = "../wgctrl-sys" } + +[target.'cfg(target_os = "macos")'.dependencies] +nix = "0.22" diff --git a/shared/src/netlink.rs b/shared/src/netlink.rs index de03ea5..95a268b 100644 --- a/shared/src/netlink.rs +++ b/shared/src/netlink.rs @@ -3,11 +3,14 @@ use netlink_packet_core::{ NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, }; use netlink_packet_route::{ - address, constants::*, link, route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, - RouteHeader, RouteMessage, RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN, + address, + constants::*, + link::{self, nlas::State}, + route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, RouteHeader, RouteMessage, + RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN, }; use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr}; -use std::io; +use std::{io, net::IpAddr}; use wgctrl::InterfaceName; fn if_nametoindex(interface: &InterfaceName) -> Result { @@ -23,7 +26,7 @@ fn if_nametoindex(interface: &InterfaceName) -> Result { fn netlink_call( message: RtnlMessage, flags: Option, -) -> Result, io::Error> { +) -> Result>, io::Error> { let mut req = NetlinkMessage::from(message); req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); req.finalize(); @@ -32,10 +35,10 @@ fn netlink_call( let len = req.buffer_len(); log::debug!("netlink request: {:?}", req); - let socket = Socket::new(NETLINK_ROUTE).unwrap(); + let socket = Socket::new(NETLINK_ROUTE)?; let kernel_addr = SocketAddr::new(0, 0); socket.connect(&kernel_addr)?; - let n_sent = socket.send(&buf[..len], 0).unwrap(); + let n_sent = socket.send(&buf[..len], 0)?; if n_sent != len { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, @@ -43,14 +46,31 @@ fn netlink_call( )); } - let n_received = socket.recv(&mut buf[..], 0).unwrap(); - let response = NetlinkMessage::::deserialize(&buf[..n_received]) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - log::trace!("netlink response: {:?}", response); - if let NetlinkPayload::Error(e) = response.payload { - return Err(e.to_io()); + let mut responses = vec![]; + loop { + let n_received = socket.recv(&mut buf[..], 0)?; + let mut offset = 0; + loop { + let bytes = &buf[offset..]; + let response = NetlinkMessage::::deserialize(bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + responses.push(response.clone()); + log::trace!("netlink response: {:?}", response); + match response.payload { + // We've parsed all parts of the response and can leave the loop. + NetlinkPayload::Ack(_) | NetlinkPayload::Done => return Ok(responses), + NetlinkPayload::Error(e) => return Err(e.into()), + _ => {}, + } + offset += response.header.length as usize; + if offset == n_received || response.header.length == 0 { + // We've fully parsed the datagram, but there may be further datagrams + // with additional netlink response parts. + log::debug!("breaking inner loop"); + break; + } + } } - Ok(response) } pub fn set_up(interface: &InterfaceName, mtu: u32) -> Result<(), io::Error> { @@ -127,3 +147,87 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result Err(e), } } + +fn get_links() -> Result, io::Error> { + let link_responses = netlink_call( + RtnlMessage::GetLink(LinkMessage::default()), + Some(NLM_F_DUMP | NLM_F_REQUEST), + )?; + let links = link_responses + .into_iter() + // Filter out non-link messages + .filter_map(|response| match response { + NetlinkMessage { + payload: NetlinkPayload::InnerMessage(RtnlMessage::NewLink(link)), + .. + } => Some(link), + _ => None, + }) + // Filter out loopback links + .filter_map(|link| if link.header.flags & IFF_LOOPBACK == 0 { + Some(link.nlas) + } else { + None + }) + // Find and filter out addresses for interfaces + .filter(|nlas| nlas.iter().any(|nla| nla == &link::nlas::Nla::OperState(State::Up))) + .filter_map(|nlas| nlas.iter().find_map(|nla| match nla { + link::nlas::Nla::IfName(name) => Some(name.clone()), + _ => None, + })) + .collect::>(); + + Ok(links) +} + +pub fn get_local_addrs() -> Result, io::Error> { + let links = get_links()?; + let addr_responses = netlink_call( + RtnlMessage::GetAddress(AddressMessage::default()), + Some(NLM_F_DUMP | NLM_F_REQUEST), + )?; + let addrs = addr_responses + .into_iter() + // Filter out non-link messages + .filter_map(|response| match response { + NetlinkMessage { + payload: NetlinkPayload::InnerMessage(RtnlMessage::NewAddress(addr)), + .. + } => Some(addr), + _ => None, + }) + // Filter out non-global-scoped addresses + .filter_map(|link| if link.header.scope == RT_SCOPE_UNIVERSE { + Some(link.nlas) + } else { + None + }) + // Only select addresses for helpful links + .filter(|nlas| nlas.iter().any(|nla| matches!(nla, address::nlas::Nla::Label(label) if links.contains(label)))) + .filter_map(|nlas| nlas.iter().find_map(|nla| match nla { + address::nlas::Nla::Address(name) if name.len() == 4 => { + let mut addr = [0u8; 4]; + addr.copy_from_slice(name); + Some(IpAddr::V4(addr.into())) + }, + address::nlas::Nla::Address(name) if name.len() == 16 => { + let mut addr = [0u8; 16]; + addr.copy_from_slice(name); + Some(IpAddr::V6(addr.into())) + }, + _ => None, + })) + .collect::>(); + Ok(addrs) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_local_addrs() { + let addrs = get_local_addrs().unwrap(); + println!("{:?}", addrs); + } +} diff --git a/shared/src/prompts.rs b/shared/src/prompts.rs index a3e4d80..27d3d61 100644 --- a/shared/src/prompts.rs +++ b/shared/src/prompts.rs @@ -285,6 +285,7 @@ pub fn add_peer( is_redeemed: false, persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS), invite_expires: Some(SystemTime::now() + invite_expires.into()), + candidates: vec![], }; Ok( diff --git a/shared/src/types.rs b/shared/src/types.rs index b458008..b4ce30f 100644 --- a/shared/src/types.rs +++ b/shared/src/types.rs @@ -7,7 +7,7 @@ use std::{ fmt::{self, Display, Formatter}, io, net::{IpAddr, SocketAddr, ToSocketAddrs}, - ops::Deref, + ops::{Deref, DerefMut}, path::Path, str::FromStr, time::{Duration, SystemTime}, @@ -20,6 +20,8 @@ use wgctrl::{ PeerInfo, }; +use crate::wg::PeerInfoExt; + #[derive(Debug, Clone)] pub struct Interface { name: InterfaceName, @@ -408,6 +410,8 @@ pub struct PeerContents { pub is_disabled: bool, pub is_redeemed: bool, pub invite_expires: Option, + #[serde(default)] + pub candidates: Vec, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] @@ -426,6 +430,12 @@ impl Deref for Peer { } } +impl DerefMut for Peer { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.contents + } +} + impl Display for Peer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{} ({})", &self.name, &self.public_key) @@ -497,19 +507,6 @@ impl<'a> PeerDiff<'a> { } } - /// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this - /// as a heuristic for "currentness" without relying on heavier things like ICMP. - pub fn peer_recently_connected(peer: &Option<&PeerInfo>) -> bool { - const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); - - let last_handshake = peer - .and_then(|p| p.stats.last_handshake_time) - .and_then(|t| t.elapsed().ok()) - .unwrap_or_else(|| SystemTime::UNIX_EPOCH.elapsed().unwrap()); - - last_handshake <= REJECT_AFTER_TIME - } - pub fn public_key(&self) -> &Key { self.builder.public_key() } @@ -570,7 +567,10 @@ impl<'a> PeerDiff<'a> { } // We won't update the endpoint if there's already a stable connection. - if !Self::peer_recently_connected(&old_info) { + if !old_info + .map(|info| info.is_recently_connected()) + .unwrap_or_default() + { let resolved = new.endpoint.as_ref().and_then(|e| e.resolve().ok()); if let Some(addr) = resolved { if old.is_none() || matches!(old, Some(old) if old.endpoint != resolved) { @@ -772,6 +772,7 @@ mod tests { is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }, }; let builder = @@ -806,6 +807,7 @@ mod tests { is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }, }; let builder = @@ -840,6 +842,7 @@ mod tests { is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }, }; let builder = diff --git a/shared/src/wg.rs b/shared/src/wg.rs index 521b783..b04398c 100644 --- a/shared/src/wg.rs +++ b/shared/src/wg.rs @@ -1,10 +1,11 @@ -use crate::{Error, IoErrorContext, NetworkOpt}; +use crate::{Error, IoErrorContext, NetworkOpt, Peer, PeerDiff}; use ipnetwork::IpNetwork; use std::{ io, net::{IpAddr, SocketAddr}, + time::Duration, }; -use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, PeerConfigBuilder}; +use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder, PeerInfo}; #[cfg(target_os = "macos")] fn cmd(bin: &str, args: &[&str]) -> Result { @@ -164,3 +165,97 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result Result, io::Error> { + use nix::{net::if_::InterfaceFlags, sys::socket::SockAddr}; + + let addrs = nix::ifaddrs::getifaddrs()? + .inspect(|addr| println!("{:?}", addr)) + .filter(|addr| { + addr.flags.contains(InterfaceFlags::IFF_UP) + && !addr.flags.intersects( + InterfaceFlags::IFF_LOOPBACK + | InterfaceFlags::IFF_POINTOPOINT + | InterfaceFlags::IFF_PROMISC, + ) + }) + .filter_map(|addr| match addr.address { + Some(SockAddr::Inet(addr)) if addr.to_std().is_ipv4() => Some(addr.to_std().ip()), + _ => None, + }) + .collect::>(); + + Ok(addrs) +} + +#[cfg(target_os = "linux")] +pub use super::netlink::get_local_addrs; + +pub trait DeviceExt { + /// Diff the output of a wgctrl device with a list of server-reported peers. + fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec>; + + // /// Get a peer by their public key, a helper function. + fn get_peer(&self, public_key: &str) -> Option<&PeerInfo>; +} + +impl DeviceExt for Device { + fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec> { + let interface_public_key = self + .public_key + .as_ref() + .map(|k| k.to_base64()) + .unwrap_or_default(); + let existing_peers = &self.peers; + + // Match existing peers (by pubkey) to new peer information from the server. + let modifications = peers.iter().filter_map(|peer| { + if peer.is_disabled || peer.public_key == interface_public_key { + None + } else { + let existing_peer = existing_peers + .iter() + .find(|p| p.config.public_key.to_base64() == peer.public_key); + PeerDiff::new(existing_peer, Some(peer)).unwrap() + } + }); + + // Remove any peers on the interface that aren't in the server's peer list any more. + let removals = existing_peers.iter().filter_map(|existing| { + let public_key = existing.config.public_key.to_base64(); + if peers.iter().any(|p| p.public_key == public_key) { + None + } else { + PeerDiff::new(Some(existing), None).unwrap() + } + }); + + modifications.chain(removals).collect::>() + } + + fn get_peer(&self, public_key: &str) -> Option<&PeerInfo> { + Key::from_base64(public_key) + .ok() + .and_then(|key| self.peers.iter().find(|peer| peer.config.public_key == key)) + } +} + +pub trait PeerInfoExt { + /// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this + /// as a heuristic for "currentness" without relying on heavier things like ICMP. + fn is_recently_connected(&self) -> bool; +} +impl PeerInfoExt for PeerInfo { + fn is_recently_connected(&self) -> bool { + const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); + + let last_handshake = self + .stats + .last_handshake_time + .and_then(|t| t.elapsed().ok()) + .unwrap_or(Duration::MAX); + + last_handshake <= REJECT_AFTER_TIME + } +}