diff --git a/Cargo.lock b/Cargo.lock index 71b8636..87100ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,9 +525,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.98" +version = "0.2.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" +checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" [[package]] name = "libsqlite3-sys" @@ -569,6 +569,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" +[[package]] +name = "memoffset" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9" +dependencies = [ + "autocfg", +] + [[package]] name = "mio" version = "0.7.13" @@ -639,6 +648,19 @@ dependencies = [ "log", ] +[[package]] +name = "nix" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7555d6c7164cc913be1ce7f95cbecdabda61eb2ccd89008524af306fb7f5031" +dependencies = [ + "bitflags", + "cc", + "cfg-if", + "libc", + "memoffset", +] + [[package]] name = "nom" version = "6.1.2" @@ -999,6 +1021,7 @@ dependencies = [ "netlink-packet-core", "netlink-packet-route", "netlink-sys", + "nix", "publicip", "regex", "serde", diff --git a/client/src/data_store.rs b/client/src/data_store.rs index df582a2..1736156 100644 --- a/client/src/data_store.rs +++ b/client/src/data_store.rs @@ -80,7 +80,7 @@ impl DataStore { /// /// Note, however, that this does not prevent a compromised server from adding a new /// peer under its control, of course. - pub fn update_peers(&mut self, current_peers: Vec) -> Result<(), Error> { + pub fn update_peers(&mut self, current_peers: &[Peer]) -> Result<(), Error> { let peers = match &mut self.contents { Contents::V1 { ref mut peers, .. } => peers, }; @@ -149,6 +149,7 @@ mod tests { is_redeemed: true, persistent_keepalive_interval: None, invite_expires: None, + candidates: vec![], } }]; static ref BASE_CIDRS: Vec = vec![Cidr { @@ -167,7 +168,7 @@ mod tests { assert_eq!(0, store.peers().len()); assert_eq!(0, store.cidrs().len()); - store.update_peers(BASE_PEERS.to_owned()).unwrap(); + store.update_peers(&BASE_PEERS).unwrap(); store.set_cidrs(BASE_CIDRS.to_owned()); store.write().unwrap(); } @@ -189,13 +190,13 @@ mod tests { DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap(); // Should work, since peer is unmodified. - store.update_peers(BASE_PEERS.clone()).unwrap(); + store.update_peers(&BASE_PEERS).unwrap(); let mut modified = BASE_PEERS.clone(); modified[0].contents.public_key = "foo".to_string(); // Should NOT work, since peer is unmodified. - assert!(store.update_peers(modified).is_err()); + assert!(store.update_peers(&modified).is_err()); } #[test] @@ -206,7 +207,7 @@ mod tests { DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap(); // Should work, since peer is unmodified. - store.update_peers(vec![]).unwrap(); + store.update_peers(&[]).unwrap(); let new_peers = BASE_PEERS .iter() .cloned() diff --git a/client/src/main.rs b/client/src/main.rs index e411883..bc67f86 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -4,13 +4,17 @@ use dialoguer::{Confirm, Input}; use hostsfile::HostsBuilder; use indoc::eprintdoc; use shared::{ - interface_config::InterfaceConfig, prompts, AddAssociationOpts, AddCidrOpts, AddPeerOpts, - Association, AssociationContents, Cidr, CidrTree, DeleteCidrOpts, EndpointContents, - InstallOpts, Interface, IoErrorContext, NetworkOpt, Peer, PeerDiff, RedeemContents, - RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR, REDEEM_TRANSITION_WAIT, + interface_config::InterfaceConfig, + prompts, + wg::{DeviceExt, PeerInfoExt}, + AddAssociationOpts, AddCidrOpts, AddPeerOpts, Association, AssociationContents, Cidr, CidrTree, + DeleteCidrOpts, Endpoint, EndpointContents, InstallOpts, Interface, IoErrorContext, NetworkOpt, + Peer, RedeemContents, RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR, + REDEEM_TRANSITION_WAIT, }; use std::{ fmt, io, + net::SocketAddr, path::{Path, PathBuf}, thread, time::Duration, @@ -19,9 +23,11 @@ use structopt::{clap::AppSettings, StructOpt}; use wgctrl::{Device, DeviceUpdate, InterfaceName, PeerConfigBuilder, PeerInfo}; mod data_store; +mod nat; mod util; use data_store::DataStore; +use nat::NatTraverse; use shared::{wg, Error}; use util::{human_duration, human_size, Api}; @@ -484,70 +490,13 @@ fn fetch( let mut store = DataStore::open_or_create(interface)?; let State { peers, cidrs } = Api::new(&config.server).http("GET", "/user/state")?; - let device_info = Device::get(interface, network.backend).with_str(interface.as_str_lossy())?; - let interface_public_key = device_info - .public_key - .as_ref() - .map(|k| k.to_base64()) - .unwrap_or_default(); - let existing_peers = &device_info.peers; - - // Match existing peers (by pubkey) to new peer information from the server. - let modifications = peers.iter().filter_map(|peer| { - if peer.is_disabled || peer.public_key == interface_public_key { - None - } else { - let existing_peer = existing_peers - .iter() - .find(|p| p.config.public_key.to_base64() == peer.public_key); - PeerDiff::new(existing_peer, Some(peer)).unwrap() - } - }); - - // Remove any peers on the interface that aren't in the server's peer list any more. - let removals = existing_peers.iter().filter_map(|existing| { - let public_key = existing.config.public_key.to_base64(); - if peers.iter().any(|p| p.public_key == public_key) { - None - } else { - PeerDiff::new(Some(existing), None).unwrap() - } - }); + let device = Device::get(interface, network.backend)?; + let modifications = device.diff(&peers); let updates = modifications - .chain(removals) - .inspect(|diff| { - let public_key = diff.public_key().to_base64(); - - let text = match (diff.old, diff.new) { - (None, Some(_)) => "added".green(), - (Some(_), Some(_)) => "modified".yellow(), - (Some(_), None) => "removed".red(), - _ => unreachable!("PeerDiff can't be None -> None"), - }; - - // Grab the peer name from either the new data, or the historical data (if the peer is removed). - let peer_hostname = match diff.new { - Some(peer) => Some(peer.name.clone()), - _ => store - .peers() - .iter() - .find(|p| p.public_key == public_key) - .map(|p| p.name.clone()), - }; - let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]"); - - log::info!( - " peer {} ({}...) was {}.", - peer_name.yellow(), - &public_key[..10].dimmed(), - text - ); - - for change in diff.changes() { - log::debug!(" {}", change); - } - }) + .iter() + .inspect(|diff| util::print_peer_diff(&store, diff)) + .cloned() .map(PeerConfigBuilder::from) .collect::>(); @@ -566,10 +515,44 @@ fn fetch( } else { log::info!("{}", "peers are already up to date.".green()); } + store.set_cidrs(cidrs); - store.update_peers(peers)?; + store.update_peers(&peers)?; store.write().with_str(interface.to_string())?; + let candidates = wg::get_local_addrs()? + .into_iter() + .map(|addr| SocketAddr::from((addr, device.listen_port.unwrap_or(51820))).into()) + .take(10) + .collect::>(); + log::info!( + "reporting {} interface address{} as NAT traversal candidates...", + candidates.len(), + if candidates.len() == 1 { "" } else { "es" } + ); + log::debug!("candidates: {:?}", candidates); + match Api::new(&config.server).http_form::<_, ()>("PUT", "/user/candidates", &candidates) { + Err(ureq::Error::Status(404, _)) => { + log::warn!("your network is using an old version of innernet-server that doesn't support NAT traversal candidate reporting.") + }, + Err(e) => return Err(e.into()), + _ => {}, + } + + log::debug!("viable ICE candidates: {:?}", candidates); + + let mut nat_traverse = NatTraverse::new(interface, network.backend, &modifications)?; + loop { + if nat_traverse.is_finished() { + break; + } + log::info!( + "Attempting to establish connection with {} remaining unconnected peers...", + nat_traverse.remaining() + ); + nat_traverse.step()?; + } + Ok(()) } @@ -992,7 +975,9 @@ fn print_peer(peer: &PeerState, short: bool, level: usize) { let pad = level * 2; let PeerState { peer, info } = peer; if short { - let connected = PeerDiff::peer_recently_connected(info); + let connected = info + .map(|info| !info.is_recently_connected()) + .unwrap_or_default(); println_pad!( pad, diff --git a/client/src/nat.rs b/client/src/nat.rs new file mode 100644 index 0000000..b7d22d5 --- /dev/null +++ b/client/src/nat.rs @@ -0,0 +1,125 @@ +//! ICE-like NAT traversal logic. +//! +//! Doesn't follow the specific ICE protocol, but takes great inspiration from RFC 8445 +//! and applies it to a protocol more specific to innernet. + +use std::time::{Duration, Instant}; + +use anyhow::Error; +use shared::{ + wg::{DeviceExt, PeerInfoExt}, + Endpoint, Peer, PeerDiff, +}; +use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder}; + +const STEP_INTERVAL: Duration = Duration::from_secs(5); + +pub struct NatTraverse<'a> { + interface: &'a InterfaceName, + backend: Backend, + remaining: Vec, +} + +impl<'a> NatTraverse<'a> { + pub fn new(interface: &'a InterfaceName, backend: Backend, diffs: &[PeerDiff]) -> Result { + let mut remaining: Vec<_> = diffs.iter().filter_map(|diff| diff.new).cloned().collect(); + + for peer in &mut remaining { + // Limit reported alternative candidates to 10. + peer.candidates.truncate(10); + + // remove server-reported endpoint from elsewhere in the list if it existed. + let endpoint = peer.endpoint.clone(); + peer.candidates + .retain(|addr| Some(addr) != endpoint.as_ref()); + } + let mut nat_traverse = Self { + interface, + backend, + remaining, + }; + nat_traverse.refresh_remaining()?; + Ok(nat_traverse) + } + + pub fn is_finished(&self) -> bool { + self.remaining.is_empty() + } + + pub fn remaining(&self) -> usize { + self.remaining.len() + } + + /// Refreshes the current state of candidate traversal attempts, returning + /// the peers that have been exhausted of all options (not included are + /// peers that have successfully connected, or peers removed from the interface). + fn refresh_remaining(&mut self) -> Result, Error> { + let device = Device::get(self.interface, self.backend)?; + // Remove connected and missing peers + self.remaining.retain(|peer| { + if let Some(peer_info) = device.get_peer(&peer.public_key) { + let recently_connected = peer_info.is_recently_connected(); + if recently_connected { + log::debug!( + "peer {} removed from NAT traverser (connected!).", + peer.name + ); + } + !recently_connected + } else { + log::debug!( + "peer {} removed from NAT traverser (no longer on interface).", + peer.name + ); + false + } + }); + let (exhausted, remaining): (Vec<_>, Vec<_>) = self + .remaining + .drain(..) + .partition(|peer| peer.candidates.is_empty()); + self.remaining = remaining; + Ok(exhausted) + } + + pub fn step(&mut self) -> Result<(), Error> { + let exhuasted = self.refresh_remaining()?; + + // Reset peer endpoints that had no viable candidates back to the server-reported one, if it exists. + let reset_updates = exhuasted + .into_iter() + .filter_map(|peer| set_endpoint(&peer.public_key, peer.endpoint.as_ref())); + + // Set all peers' endpoints to their next available candidate. + let candidate_updates = self.remaining.iter_mut().filter_map(|peer| { + let endpoint = peer.candidates.pop(); + set_endpoint(&peer.public_key, endpoint.as_ref()) + }); + + let updates: Vec<_> = reset_updates.chain(candidate_updates).collect(); + + DeviceUpdate::new() + .add_peers(&updates) + .apply(self.interface, self.backend)?; + + let start = Instant::now(); + while start.elapsed() < STEP_INTERVAL { + self.refresh_remaining()?; + if self.is_finished() { + log::debug!("NAT traverser is finished!"); + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + Ok(()) + } +} + +/// Return a PeerConfigBuilder if an endpoint exists and resolves successfully. +fn set_endpoint(public_key: &str, endpoint: Option<&Endpoint>) -> Option { + endpoint + .and_then(|endpoint| endpoint.resolve().ok()) + .map(|addr| { + PeerConfigBuilder::new(&Key::from_base64(public_key).unwrap()).set_endpoint(addr) + }) +} diff --git a/client/src/util.rs b/client/src/util.rs index 310fbe8..d853685 100644 --- a/client/src/util.rs +++ b/client/src/util.rs @@ -1,9 +1,9 @@ -use crate::{ClientError, Error}; +use crate::data_store::DataStore; use colored::*; use indoc::eprintdoc; use log::{Level, LevelFilter}; use serde::{de::DeserializeOwned, Serialize}; -use shared::{interface_config::ServerInfo, INNERNET_PUBKEY_HEADER}; +use shared::{interface_config::ServerInfo, PeerDiff, INNERNET_PUBKEY_HEADER}; use std::{io, time::Duration}; use ureq::{Agent, AgentBuilder}; @@ -137,6 +137,39 @@ pub fn permissions_helptext(e: &io::Error) { } } +pub fn print_peer_diff(store: &DataStore, diff: &PeerDiff) { + let public_key = diff.public_key().to_base64(); + + let text = match (diff.old, diff.new) { + (None, Some(_)) => "added".green(), + (Some(_), Some(_)) => "modified".yellow(), + (Some(_), None) => "removed".red(), + _ => unreachable!("PeerDiff can't be None -> None"), + }; + + // Grab the peer name from either the new data, or the historical data (if the peer is removed). + let peer_hostname = match diff.new { + Some(peer) => Some(peer.name.clone()), + None => store + .peers() + .iter() + .find(|p| p.public_key == public_key) + .map(|p| p.name.clone()), + }; + let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]"); + + log::info!( + " peer {} ({}...) was {}.", + peer_name.yellow(), + &public_key[..10].dimmed(), + text + ); + + for change in diff.changes() { + log::debug!(" {}", change); + } +} + pub struct Api<'a> { agent: Agent, server: &'a ServerInfo, @@ -151,7 +184,7 @@ impl<'a> Api<'a> { Self { agent, server } } - pub fn http(&self, verb: &str, endpoint: &str) -> Result { + pub fn http(&self, verb: &str, endpoint: &str) -> Result { self.request::<(), _>(verb, endpoint, None) } @@ -160,7 +193,7 @@ impl<'a> Api<'a> { verb: &str, endpoint: &str, form: S, - ) -> Result { + ) -> Result { self.request(verb, endpoint, Some(form)) } @@ -169,7 +202,7 @@ impl<'a> Api<'a> { verb: &str, endpoint: &str, form: Option, - ) -> Result { + ) -> Result { let request = self .agent .request( @@ -179,7 +212,12 @@ impl<'a> Api<'a> { .set(INNERNET_PUBKEY_HEADER, &self.server.public_key); let response = if let Some(form) = form { - request.send_json(serde_json::to_value(form)?)? + request.send_json(serde_json::to_value(form).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to serialize JSON request: {}", e), + ) + })?)? } else { request.call()? }; @@ -190,10 +228,13 @@ impl<'a> Api<'a> { response = "null".into(); } Ok(serde_json::from_str(&response).map_err(|e| { - ClientError(format!( - "failed to deserialize JSON response from the server: {}, response={}", - e, &response - )) + io::Error::new( + io::ErrorKind::InvalidData, + format!( + "failed to deserialize JSON response from the server: {}, response={}", + e, &response + ), + ) })?) } } diff --git a/server/src/api/user.rs b/server/src/api/user.rs index 40c189c..ffedd76 100644 --- a/server/src/api/user.rs +++ b/server/src/api/user.rs @@ -36,11 +36,20 @@ pub async fn routes( let form = form_body(req).await?; handlers::endpoint(form, session).await }, + (&Method::PUT, Some("candidates")) => { + if !session.user_capable() { + return Err(ServerError::Unauthorized); + } + let form = form_body(req).await?; + handlers::candidates(form, session).await + }, _ => Err(ServerError::NotFound), } } mod handlers { + use shared::Endpoint; + use super::*; /// Get the current state of the network, in the eyes of the current peer. @@ -115,14 +124,29 @@ mod handlers { status_response(StatusCode::NO_CONTENT) } - /// Redeems an invitation. An invitation includes a WireGuard keypair generated by either the server - /// or a peer with admin rights. - /// - /// Redemption is the process of an invitee generating their own keypair and exchanging their temporary - /// key with their permanent one. - /// - /// Until this API endpoint is called, the invited peer will not show up to other peers, and once - /// it is called and succeeds, it cannot be called again. + /// Report any other endpoint candidates that can be tried by peers to connect. + /// Currently limited to 10 candidates max. + pub async fn candidates( + contents: Vec, + session: Session, + ) -> Result, ServerError> { + if contents.len() > 10 { + return status_response(StatusCode::PAYLOAD_TOO_LARGE); + } + let conn = session.context.db.lock(); + let mut selected_peer = DatabasePeer::get(&conn, session.peer.id)?; + selected_peer.update( + &conn, + PeerContents { + candidates: contents, + ..selected_peer.contents.clone() + }, + )?; + + status_response(StatusCode::NO_CONTENT) + } + + /// Force a specific endpoint to be reported by the server. pub async fn endpoint( contents: EndpointContents, session: Session, @@ -148,7 +172,7 @@ mod tests { use super::*; use crate::{db::DatabaseAssociation, test}; use bytes::Buf; - use shared::{AssociationContents, CidrContents, EndpointContents, Error}; + use shared::{AssociationContents, CidrContents, Endpoint, EndpointContents, Error}; #[tokio::test] async fn test_get_state_from_developer1() -> Result<(), Error> { @@ -406,4 +430,36 @@ mod tests { assert_eq!(res.status(), StatusCode::UNAUTHORIZED); Ok(()) } + + #[tokio::test] + async fn test_candidates() -> Result<(), Error> { + let server = test::Server::new()?; + + let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?; + assert_eq!(peer.candidates, vec![]); + + let candidates = vec!["1.1.1.1:51820".parse::().unwrap()]; + assert_eq!( + server + .form_request( + test::DEVELOPER1_PEER_IP, + "PUT", + "/v1/user/candidates", + &candidates + ) + .await + .status(), + StatusCode::NO_CONTENT + ); + + let res = server + .request(test::DEVELOPER1_PEER_IP, "GET", "/v1/user/state") + .await; + + assert_eq!(res.status(), StatusCode::OK); + + let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?; + assert_eq!(peer.candidates, candidates); + Ok(()) + } } diff --git a/server/src/db/mod.rs b/server/src/db/mod.rs index 60a7912..21ae8c9 100644 --- a/server/src/db/mod.rs +++ b/server/src/db/mod.rs @@ -8,11 +8,13 @@ pub use peer::DatabasePeer; use rusqlite::params; const INVITE_EXPIRATION_VERSION: usize = 1; +const ENDPOINT_CANDIDATES_VERSION: usize = 2; -pub const CURRENT_VERSION: usize = INVITE_EXPIRATION_VERSION; +pub const CURRENT_VERSION: usize = ENDPOINT_CANDIDATES_VERSION; pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { let old_version: usize = conn.pragma_query_value(None, "user_version", |r| r.get(0))?; + log::debug!("user_version: {}", old_version); if old_version < INVITE_EXPIRATION_VERSION { conn.execute( @@ -21,8 +23,12 @@ pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> )?; } - conn.pragma_update(None, "user_version", &CURRENT_VERSION)?; + if old_version < ENDPOINT_CANDIDATES_VERSION { + conn.execute("ALTER TABLE peers ADD COLUMN candidates TEXT", params![])?; + } + if old_version != CURRENT_VERSION { + conn.pragma_update(None, "user_version", &CURRENT_VERSION)?; log::info!( "migrated db version from {} to {}", old_version, diff --git a/server/src/db/peer.rs b/server/src/db/peer.rs index 5f729a0..5bf7973 100644 --- a/server/src/db/peer.rs +++ b/server/src/db/peer.rs @@ -2,8 +2,8 @@ use super::DatabaseCidr; use crate::ServerError; use lazy_static::lazy_static; use regex::Regex; -use rusqlite::{params, Connection}; -use shared::{Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS}; +use rusqlite::{params, types::Type, Connection}; +use shared::{Endpoint, Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS}; use std::{ net::IpAddr, ops::{Deref, DerefMut}, @@ -22,12 +22,27 @@ pub static CREATE_TABLE_SQL: &str = "CREATE TABLE peers ( is_disabled INTEGER DEFAULT 0 NOT NULL, /* Is the peer disabled? (peers cannot be deleted) */ is_redeemed INTEGER DEFAULT 0 NOT NULL, /* Has the peer redeemed their invite yet? */ invite_expires INTEGER, /* The UNIX time that an invited peer can no longer redeem. */ + candidates TEXT, /* A list of additional endpoints that peers can use to connect. */ FOREIGN KEY (cidr_id) REFERENCES cidrs (id) ON UPDATE RESTRICT ON DELETE RESTRICT )"; +pub static COLUMNS: &[&str] = &[ + "id", + "name", + "ip", + "cidr_id", + "public_key", + "endpoint", + "is_admin", + "is_disabled", + "is_redeemed", + "invite_expires", + "candidates", +]; + lazy_static! { /// Regex to match the requirements of hostname(7), needed to have peers also be reachable hostnames. /// Note that the full length also must be maximum 63 characters, which this regex does not check. @@ -71,6 +86,7 @@ impl DatabasePeer { is_disabled, is_redeemed, invite_expires, + candidates, .. } = &contents; log::info!("creating peer {:?}", contents); @@ -96,8 +112,13 @@ impl DatabasePeer { .flatten() .map(|t| t.as_secs()); + let candidates = serde_json::to_string(candidates)?; + conn.execute( - "INSERT INTO peers (name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + &format!( + "INSERT INTO peers ({}) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", + COLUMNS[1..].join(", ") + ), params![ &**name, ip.to_string(), @@ -108,6 +129,7 @@ impl DatabasePeer { is_disabled, is_redeemed, invite_expires, + candidates, ], )?; let id = conn.last_insert_rowid(); @@ -135,17 +157,21 @@ impl DatabasePeer { endpoint: contents.endpoint, is_admin: contents.is_admin, is_disabled: contents.is_disabled, + candidates: contents.candidates, ..self.contents.clone() }; + let new_candidates = serde_json::to_string(&new_contents.candidates)?; conn.execute( "UPDATE peers SET - name = ?1, - endpoint = ?2, - is_admin = ?3, - is_disabled = ?4 - WHERE id = ?5", + name = ?2, + endpoint = ?3, + is_admin = ?4, + is_disabled = ?5, + candidates = ?6 + WHERE id = ?1", params![ + self.id, &*new_contents.name, new_contents .endpoint @@ -153,7 +179,7 @@ impl DatabasePeer { .map(|endpoint| endpoint.to_string()), new_contents.is_admin, new_contents.is_disabled, - self.id, + new_candidates, ], )?; @@ -198,11 +224,11 @@ impl DatabasePeer { let name = row .get::<_, String>(1)? .parse() - .map_err(|_| rusqlite::Error::ExecuteReturnedResults)?; + .map_err(|_| rusqlite::Error::InvalidColumnType(1, "hostname".into(), Type::Text))?; let ip: IpAddr = row .get::<_, String>(2)? .parse() - .map_err(|_| rusqlite::Error::ExecuteReturnedResults)?; + .map_err(|_| rusqlite::Error::InvalidColumnType(2, "ip".into(), Type::Text))?; let cidr_id = row.get(3)?; let public_key = row.get(4)?; let endpoint = row @@ -214,6 +240,10 @@ impl DatabasePeer { let invite_expires = row .get::<_, Option>(9)? .map(|unixtime| SystemTime::UNIX_EPOCH + Duration::from_secs(unixtime)); + let candidates_str: String = row.get(10)?; + let candidates: Vec = serde_json::from_str(&candidates_str).map_err(|_| { + rusqlite::Error::InvalidColumnType(10, "candidates (json)".into(), Type::Text) + })?; let persistent_keepalive_interval = Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS); @@ -230,6 +260,7 @@ impl DatabasePeer { is_disabled, is_redeemed, invite_expires, + candidates, }, } .into()) @@ -237,10 +268,7 @@ impl DatabasePeer { pub fn get(conn: &Connection, id: i64) -> Result { let result = conn.query_row( - "SELECT - id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires - FROM peers - WHERE id = ?1", + &format!("SELECT {} FROM peers WHERE id = ?1", COLUMNS.join(", ")), params![id], Self::from_row, )?; @@ -250,10 +278,7 @@ impl DatabasePeer { pub fn get_from_ip(conn: &Connection, ip: IpAddr) -> Result { let result = conn.query_row( - "SELECT - id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires - FROM peers - WHERE ip = ?1", + &format!("SELECT {} FROM peers WHERE ip = ?1", COLUMNS.join(", ")), params![ip.to_string()], Self::from_row, )?; @@ -271,7 +296,7 @@ impl DatabasePeer { // // NOTE that a forced association is created with the special "infra" CIDR with id 2 (1 being the root). let mut stmt = conn.prepare_cached( - "WITH + &format!("WITH parent_of(id, parent) AS ( SELECT id, parent FROM cidrs WHERE id = ?1 UNION ALL @@ -289,10 +314,12 @@ impl DatabasePeer { UNION SELECT id FROM cidrs, associated_subcidrs WHERE cidrs.parent=associated_subcidrs.cidr_id ) - SELECT DISTINCT peers.id, peers.name, peers.ip, peers.cidr_id, peers.public_key, peers.endpoint, peers.is_admin, peers.is_disabled, peers.is_redeemed, peers.invite_expires + SELECT DISTINCT {} FROM peers JOIN associated_subcidrs ON peers.cidr_id=associated_subcidrs.cidr_id WHERE peers.is_disabled = 0 AND peers.is_redeemed = 1;", + COLUMNS.iter().map(|col| format!("peers.{}", col)).collect::>().join(", ") + ), )?; let peers = stmt .query_map(params![self.cidr_id], Self::from_row)? @@ -301,9 +328,7 @@ impl DatabasePeer { } pub fn list(conn: &Connection) -> Result, ServerError> { - let mut stmt = conn.prepare_cached( - "SELECT id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires FROM peers", - )?; + let mut stmt = conn.prepare_cached(&format!("SELECT {} FROM peers", COLUMNS.join(", ")))?; let peer_iter = stmt.query_map(params![], Self::from_row)?; Ok(peer_iter.collect::>()?) diff --git a/server/src/initialize.rs b/server/src/initialize.rs index e440071..3695475 100644 --- a/server/src/initialize.rs +++ b/server/src/initialize.rs @@ -17,6 +17,7 @@ fn create_database>( conn.execute(db::association::CREATE_TABLE_SQL, params![])?; conn.execute(db::cidr::CREATE_TABLE_SQL, params![])?; conn.pragma_update(None, "user_version", &db::CURRENT_VERSION)?; + log::debug!("set database version to db::CURRENT_VERSION"); Ok(conn) } @@ -89,6 +90,7 @@ fn populate_database(conn: &Connection, db_init_data: DbInitData) -> Result<(), is_redeemed: true, persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS), invite_expires: None, + candidates: vec![], }, ) .map_err(|_| anyhow!("failed to create innernet peer."))?; diff --git a/server/src/main.rs b/server/src/main.rs index 5157ffe..2031f8f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -476,9 +476,11 @@ async fn serve( network: NetworkOpt, ) -> Result<(), Error> { let config = ConfigFile::from_file(conf.config_path(&interface))?; + log::debug!("opening database connection..."); let conn = open_database_connection(&interface, conf)?; let peers = DatabasePeer::list(&conn)?; + log::debug!("peers listed..."); let peer_configs = peers .iter() .map(|peer| peer.deref().into()) diff --git a/server/src/test.rs b/server/src/test.rs index 0f1b3e9..7d90e03 100644 --- a/server/src/test.rs +++ b/server/src/test.rs @@ -231,6 +231,7 @@ pub fn peer_contents( is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }) } diff --git a/shared/Cargo.toml b/shared/Cargo.toml index 84048fa..169d25f 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -29,3 +29,6 @@ netlink-sys = "0.7" netlink-packet-core = "0.2" netlink-packet-route = "0.7" wgctrl-sys = { path = "../wgctrl-sys" } + +[target.'cfg(target_os = "macos")'.dependencies] +nix = "0.22" diff --git a/shared/src/netlink.rs b/shared/src/netlink.rs index de03ea5..95a268b 100644 --- a/shared/src/netlink.rs +++ b/shared/src/netlink.rs @@ -3,11 +3,14 @@ use netlink_packet_core::{ NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, }; use netlink_packet_route::{ - address, constants::*, link, route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, - RouteHeader, RouteMessage, RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN, + address, + constants::*, + link::{self, nlas::State}, + route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, RouteHeader, RouteMessage, + RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN, }; use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr}; -use std::io; +use std::{io, net::IpAddr}; use wgctrl::InterfaceName; fn if_nametoindex(interface: &InterfaceName) -> Result { @@ -23,7 +26,7 @@ fn if_nametoindex(interface: &InterfaceName) -> Result { fn netlink_call( message: RtnlMessage, flags: Option, -) -> Result, io::Error> { +) -> Result>, io::Error> { let mut req = NetlinkMessage::from(message); req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); req.finalize(); @@ -32,10 +35,10 @@ fn netlink_call( let len = req.buffer_len(); log::debug!("netlink request: {:?}", req); - let socket = Socket::new(NETLINK_ROUTE).unwrap(); + let socket = Socket::new(NETLINK_ROUTE)?; let kernel_addr = SocketAddr::new(0, 0); socket.connect(&kernel_addr)?; - let n_sent = socket.send(&buf[..len], 0).unwrap(); + let n_sent = socket.send(&buf[..len], 0)?; if n_sent != len { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, @@ -43,14 +46,31 @@ fn netlink_call( )); } - let n_received = socket.recv(&mut buf[..], 0).unwrap(); - let response = NetlinkMessage::::deserialize(&buf[..n_received]) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - log::trace!("netlink response: {:?}", response); - if let NetlinkPayload::Error(e) = response.payload { - return Err(e.to_io()); + let mut responses = vec![]; + loop { + let n_received = socket.recv(&mut buf[..], 0)?; + let mut offset = 0; + loop { + let bytes = &buf[offset..]; + let response = NetlinkMessage::::deserialize(bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + responses.push(response.clone()); + log::trace!("netlink response: {:?}", response); + match response.payload { + // We've parsed all parts of the response and can leave the loop. + NetlinkPayload::Ack(_) | NetlinkPayload::Done => return Ok(responses), + NetlinkPayload::Error(e) => return Err(e.into()), + _ => {}, + } + offset += response.header.length as usize; + if offset == n_received || response.header.length == 0 { + // We've fully parsed the datagram, but there may be further datagrams + // with additional netlink response parts. + log::debug!("breaking inner loop"); + break; + } + } } - Ok(response) } pub fn set_up(interface: &InterfaceName, mtu: u32) -> Result<(), io::Error> { @@ -127,3 +147,87 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result Err(e), } } + +fn get_links() -> Result, io::Error> { + let link_responses = netlink_call( + RtnlMessage::GetLink(LinkMessage::default()), + Some(NLM_F_DUMP | NLM_F_REQUEST), + )?; + let links = link_responses + .into_iter() + // Filter out non-link messages + .filter_map(|response| match response { + NetlinkMessage { + payload: NetlinkPayload::InnerMessage(RtnlMessage::NewLink(link)), + .. + } => Some(link), + _ => None, + }) + // Filter out loopback links + .filter_map(|link| if link.header.flags & IFF_LOOPBACK == 0 { + Some(link.nlas) + } else { + None + }) + // Find and filter out addresses for interfaces + .filter(|nlas| nlas.iter().any(|nla| nla == &link::nlas::Nla::OperState(State::Up))) + .filter_map(|nlas| nlas.iter().find_map(|nla| match nla { + link::nlas::Nla::IfName(name) => Some(name.clone()), + _ => None, + })) + .collect::>(); + + Ok(links) +} + +pub fn get_local_addrs() -> Result, io::Error> { + let links = get_links()?; + let addr_responses = netlink_call( + RtnlMessage::GetAddress(AddressMessage::default()), + Some(NLM_F_DUMP | NLM_F_REQUEST), + )?; + let addrs = addr_responses + .into_iter() + // Filter out non-link messages + .filter_map(|response| match response { + NetlinkMessage { + payload: NetlinkPayload::InnerMessage(RtnlMessage::NewAddress(addr)), + .. + } => Some(addr), + _ => None, + }) + // Filter out non-global-scoped addresses + .filter_map(|link| if link.header.scope == RT_SCOPE_UNIVERSE { + Some(link.nlas) + } else { + None + }) + // Only select addresses for helpful links + .filter(|nlas| nlas.iter().any(|nla| matches!(nla, address::nlas::Nla::Label(label) if links.contains(label)))) + .filter_map(|nlas| nlas.iter().find_map(|nla| match nla { + address::nlas::Nla::Address(name) if name.len() == 4 => { + let mut addr = [0u8; 4]; + addr.copy_from_slice(name); + Some(IpAddr::V4(addr.into())) + }, + address::nlas::Nla::Address(name) if name.len() == 16 => { + let mut addr = [0u8; 16]; + addr.copy_from_slice(name); + Some(IpAddr::V6(addr.into())) + }, + _ => None, + })) + .collect::>(); + Ok(addrs) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_local_addrs() { + let addrs = get_local_addrs().unwrap(); + println!("{:?}", addrs); + } +} diff --git a/shared/src/prompts.rs b/shared/src/prompts.rs index a3e4d80..27d3d61 100644 --- a/shared/src/prompts.rs +++ b/shared/src/prompts.rs @@ -285,6 +285,7 @@ pub fn add_peer( is_redeemed: false, persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS), invite_expires: Some(SystemTime::now() + invite_expires.into()), + candidates: vec![], }; Ok( diff --git a/shared/src/types.rs b/shared/src/types.rs index b458008..b4ce30f 100644 --- a/shared/src/types.rs +++ b/shared/src/types.rs @@ -7,7 +7,7 @@ use std::{ fmt::{self, Display, Formatter}, io, net::{IpAddr, SocketAddr, ToSocketAddrs}, - ops::Deref, + ops::{Deref, DerefMut}, path::Path, str::FromStr, time::{Duration, SystemTime}, @@ -20,6 +20,8 @@ use wgctrl::{ PeerInfo, }; +use crate::wg::PeerInfoExt; + #[derive(Debug, Clone)] pub struct Interface { name: InterfaceName, @@ -408,6 +410,8 @@ pub struct PeerContents { pub is_disabled: bool, pub is_redeemed: bool, pub invite_expires: Option, + #[serde(default)] + pub candidates: Vec, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] @@ -426,6 +430,12 @@ impl Deref for Peer { } } +impl DerefMut for Peer { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.contents + } +} + impl Display for Peer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{} ({})", &self.name, &self.public_key) @@ -497,19 +507,6 @@ impl<'a> PeerDiff<'a> { } } - /// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this - /// as a heuristic for "currentness" without relying on heavier things like ICMP. - pub fn peer_recently_connected(peer: &Option<&PeerInfo>) -> bool { - const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); - - let last_handshake = peer - .and_then(|p| p.stats.last_handshake_time) - .and_then(|t| t.elapsed().ok()) - .unwrap_or_else(|| SystemTime::UNIX_EPOCH.elapsed().unwrap()); - - last_handshake <= REJECT_AFTER_TIME - } - pub fn public_key(&self) -> &Key { self.builder.public_key() } @@ -570,7 +567,10 @@ impl<'a> PeerDiff<'a> { } // We won't update the endpoint if there's already a stable connection. - if !Self::peer_recently_connected(&old_info) { + if !old_info + .map(|info| info.is_recently_connected()) + .unwrap_or_default() + { let resolved = new.endpoint.as_ref().and_then(|e| e.resolve().ok()); if let Some(addr) = resolved { if old.is_none() || matches!(old, Some(old) if old.endpoint != resolved) { @@ -772,6 +772,7 @@ mod tests { is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }, }; let builder = @@ -806,6 +807,7 @@ mod tests { is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }, }; let builder = @@ -840,6 +842,7 @@ mod tests { is_disabled: false, is_redeemed: true, invite_expires: None, + candidates: vec![], }, }; let builder = diff --git a/shared/src/wg.rs b/shared/src/wg.rs index 521b783..b04398c 100644 --- a/shared/src/wg.rs +++ b/shared/src/wg.rs @@ -1,10 +1,11 @@ -use crate::{Error, IoErrorContext, NetworkOpt}; +use crate::{Error, IoErrorContext, NetworkOpt, Peer, PeerDiff}; use ipnetwork::IpNetwork; use std::{ io, net::{IpAddr, SocketAddr}, + time::Duration, }; -use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, PeerConfigBuilder}; +use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder, PeerInfo}; #[cfg(target_os = "macos")] fn cmd(bin: &str, args: &[&str]) -> Result { @@ -164,3 +165,97 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result Result, io::Error> { + use nix::{net::if_::InterfaceFlags, sys::socket::SockAddr}; + + let addrs = nix::ifaddrs::getifaddrs()? + .inspect(|addr| println!("{:?}", addr)) + .filter(|addr| { + addr.flags.contains(InterfaceFlags::IFF_UP) + && !addr.flags.intersects( + InterfaceFlags::IFF_LOOPBACK + | InterfaceFlags::IFF_POINTOPOINT + | InterfaceFlags::IFF_PROMISC, + ) + }) + .filter_map(|addr| match addr.address { + Some(SockAddr::Inet(addr)) if addr.to_std().is_ipv4() => Some(addr.to_std().ip()), + _ => None, + }) + .collect::>(); + + Ok(addrs) +} + +#[cfg(target_os = "linux")] +pub use super::netlink::get_local_addrs; + +pub trait DeviceExt { + /// Diff the output of a wgctrl device with a list of server-reported peers. + fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec>; + + // /// Get a peer by their public key, a helper function. + fn get_peer(&self, public_key: &str) -> Option<&PeerInfo>; +} + +impl DeviceExt for Device { + fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec> { + let interface_public_key = self + .public_key + .as_ref() + .map(|k| k.to_base64()) + .unwrap_or_default(); + let existing_peers = &self.peers; + + // Match existing peers (by pubkey) to new peer information from the server. + let modifications = peers.iter().filter_map(|peer| { + if peer.is_disabled || peer.public_key == interface_public_key { + None + } else { + let existing_peer = existing_peers + .iter() + .find(|p| p.config.public_key.to_base64() == peer.public_key); + PeerDiff::new(existing_peer, Some(peer)).unwrap() + } + }); + + // Remove any peers on the interface that aren't in the server's peer list any more. + let removals = existing_peers.iter().filter_map(|existing| { + let public_key = existing.config.public_key.to_base64(); + if peers.iter().any(|p| p.public_key == public_key) { + None + } else { + PeerDiff::new(Some(existing), None).unwrap() + } + }); + + modifications.chain(removals).collect::>() + } + + fn get_peer(&self, public_key: &str) -> Option<&PeerInfo> { + Key::from_base64(public_key) + .ok() + .and_then(|key| self.peers.iter().find(|peer| peer.config.public_key == key)) + } +} + +pub trait PeerInfoExt { + /// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this + /// as a heuristic for "currentness" without relying on heavier things like ICMP. + fn is_recently_connected(&self) -> bool; +} +impl PeerInfoExt for PeerInfo { + fn is_recently_connected(&self) -> bool { + const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); + + let last_handshake = self + .stats + .last_handshake_time + .and_then(|t| t.elapsed().ok()) + .unwrap_or(Duration::MAX); + + last_handshake <= REJECT_AFTER_TIME + } +}