NAT traversal: ICE-esque candidate selection (#134)

This change adds the ability for peers to report additional candidate endpoints for other peers to attempt connections with outside of the endpoint reported by the coordinating server.

While not a complete solution to the full spectrum of NAT traversal issues (TURN-esque proxying is still notably missing), it allows peers within the same NAT to connect to each other via their LAN addresses, which is a win nonetheless. In the future, more advanced candidate discovery could be used to punch through additional types of NAT cone types as well.

Co-authored-by: Matěj Laitl <matej@laitl.cz>
pull/136/head
Jake McGinty 2021-09-01 18:58:46 +09:00 committed by GitHub
parent fd06b8054d
commit 8903604caa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 623 additions and 150 deletions

27
Cargo.lock generated
View File

@ -525,9 +525,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.98" version = "0.2.101"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21"
[[package]] [[package]]
name = "libsqlite3-sys" name = "libsqlite3-sys"
@ -569,6 +569,15 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc"
[[package]]
name = "memoffset"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "mio" name = "mio"
version = "0.7.13" version = "0.7.13"
@ -639,6 +648,19 @@ dependencies = [
"log", "log",
] ]
[[package]]
name = "nix"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7555d6c7164cc913be1ce7f95cbecdabda61eb2ccd89008524af306fb7f5031"
dependencies = [
"bitflags",
"cc",
"cfg-if",
"libc",
"memoffset",
]
[[package]] [[package]]
name = "nom" name = "nom"
version = "6.1.2" version = "6.1.2"
@ -999,6 +1021,7 @@ dependencies = [
"netlink-packet-core", "netlink-packet-core",
"netlink-packet-route", "netlink-packet-route",
"netlink-sys", "netlink-sys",
"nix",
"publicip", "publicip",
"regex", "regex",
"serde", "serde",

View File

@ -80,7 +80,7 @@ impl DataStore {
/// ///
/// Note, however, that this does not prevent a compromised server from adding a new /// Note, however, that this does not prevent a compromised server from adding a new
/// peer under its control, of course. /// peer under its control, of course.
pub fn update_peers(&mut self, current_peers: Vec<Peer>) -> Result<(), Error> { pub fn update_peers(&mut self, current_peers: &[Peer]) -> Result<(), Error> {
let peers = match &mut self.contents { let peers = match &mut self.contents {
Contents::V1 { ref mut peers, .. } => peers, Contents::V1 { ref mut peers, .. } => peers,
}; };
@ -149,6 +149,7 @@ mod tests {
is_redeemed: true, is_redeemed: true,
persistent_keepalive_interval: None, persistent_keepalive_interval: None,
invite_expires: None, invite_expires: None,
candidates: vec![],
} }
}]; }];
static ref BASE_CIDRS: Vec<Cidr> = vec![Cidr { static ref BASE_CIDRS: Vec<Cidr> = vec![Cidr {
@ -167,7 +168,7 @@ mod tests {
assert_eq!(0, store.peers().len()); assert_eq!(0, store.peers().len());
assert_eq!(0, store.cidrs().len()); assert_eq!(0, store.cidrs().len());
store.update_peers(BASE_PEERS.to_owned()).unwrap(); store.update_peers(&BASE_PEERS).unwrap();
store.set_cidrs(BASE_CIDRS.to_owned()); store.set_cidrs(BASE_CIDRS.to_owned());
store.write().unwrap(); store.write().unwrap();
} }
@ -189,13 +190,13 @@ mod tests {
DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap(); DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap();
// Should work, since peer is unmodified. // Should work, since peer is unmodified.
store.update_peers(BASE_PEERS.clone()).unwrap(); store.update_peers(&BASE_PEERS).unwrap();
let mut modified = BASE_PEERS.clone(); let mut modified = BASE_PEERS.clone();
modified[0].contents.public_key = "foo".to_string(); modified[0].contents.public_key = "foo".to_string();
// Should NOT work, since peer is unmodified. // Should NOT work, since peer is unmodified.
assert!(store.update_peers(modified).is_err()); assert!(store.update_peers(&modified).is_err());
} }
#[test] #[test]
@ -206,7 +207,7 @@ mod tests {
DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap(); DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap();
// Should work, since peer is unmodified. // Should work, since peer is unmodified.
store.update_peers(vec![]).unwrap(); store.update_peers(&[]).unwrap();
let new_peers = BASE_PEERS let new_peers = BASE_PEERS
.iter() .iter()
.cloned() .cloned()

View File

@ -4,13 +4,17 @@ use dialoguer::{Confirm, Input};
use hostsfile::HostsBuilder; use hostsfile::HostsBuilder;
use indoc::eprintdoc; use indoc::eprintdoc;
use shared::{ use shared::{
interface_config::InterfaceConfig, prompts, AddAssociationOpts, AddCidrOpts, AddPeerOpts, interface_config::InterfaceConfig,
Association, AssociationContents, Cidr, CidrTree, DeleteCidrOpts, EndpointContents, prompts,
InstallOpts, Interface, IoErrorContext, NetworkOpt, Peer, PeerDiff, RedeemContents, wg::{DeviceExt, PeerInfoExt},
RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR, REDEEM_TRANSITION_WAIT, AddAssociationOpts, AddCidrOpts, AddPeerOpts, Association, AssociationContents, Cidr, CidrTree,
DeleteCidrOpts, Endpoint, EndpointContents, InstallOpts, Interface, IoErrorContext, NetworkOpt,
Peer, RedeemContents, RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR,
REDEEM_TRANSITION_WAIT,
}; };
use std::{ use std::{
fmt, io, fmt, io,
net::SocketAddr,
path::{Path, PathBuf}, path::{Path, PathBuf},
thread, thread,
time::Duration, time::Duration,
@ -19,9 +23,11 @@ use structopt::{clap::AppSettings, StructOpt};
use wgctrl::{Device, DeviceUpdate, InterfaceName, PeerConfigBuilder, PeerInfo}; use wgctrl::{Device, DeviceUpdate, InterfaceName, PeerConfigBuilder, PeerInfo};
mod data_store; mod data_store;
mod nat;
mod util; mod util;
use data_store::DataStore; use data_store::DataStore;
use nat::NatTraverse;
use shared::{wg, Error}; use shared::{wg, Error};
use util::{human_duration, human_size, Api}; use util::{human_duration, human_size, Api};
@ -484,70 +490,13 @@ fn fetch(
let mut store = DataStore::open_or_create(interface)?; let mut store = DataStore::open_or_create(interface)?;
let State { peers, cidrs } = Api::new(&config.server).http("GET", "/user/state")?; let State { peers, cidrs } = Api::new(&config.server).http("GET", "/user/state")?;
let device_info = Device::get(interface, network.backend).with_str(interface.as_str_lossy())?; let device = Device::get(interface, network.backend)?;
let interface_public_key = device_info let modifications = device.diff(&peers);
.public_key
.as_ref()
.map(|k| k.to_base64())
.unwrap_or_default();
let existing_peers = &device_info.peers;
// Match existing peers (by pubkey) to new peer information from the server.
let modifications = peers.iter().filter_map(|peer| {
if peer.is_disabled || peer.public_key == interface_public_key {
None
} else {
let existing_peer = existing_peers
.iter()
.find(|p| p.config.public_key.to_base64() == peer.public_key);
PeerDiff::new(existing_peer, Some(peer)).unwrap()
}
});
// Remove any peers on the interface that aren't in the server's peer list any more.
let removals = existing_peers.iter().filter_map(|existing| {
let public_key = existing.config.public_key.to_base64();
if peers.iter().any(|p| p.public_key == public_key) {
None
} else {
PeerDiff::new(Some(existing), None).unwrap()
}
});
let updates = modifications let updates = modifications
.chain(removals)
.inspect(|diff| {
let public_key = diff.public_key().to_base64();
let text = match (diff.old, diff.new) {
(None, Some(_)) => "added".green(),
(Some(_), Some(_)) => "modified".yellow(),
(Some(_), None) => "removed".red(),
_ => unreachable!("PeerDiff can't be None -> None"),
};
// Grab the peer name from either the new data, or the historical data (if the peer is removed).
let peer_hostname = match diff.new {
Some(peer) => Some(peer.name.clone()),
_ => store
.peers()
.iter() .iter()
.find(|p| p.public_key == public_key) .inspect(|diff| util::print_peer_diff(&store, diff))
.map(|p| p.name.clone()), .cloned()
};
let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]");
log::info!(
" peer {} ({}...) was {}.",
peer_name.yellow(),
&public_key[..10].dimmed(),
text
);
for change in diff.changes() {
log::debug!(" {}", change);
}
})
.map(PeerConfigBuilder::from) .map(PeerConfigBuilder::from)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -566,10 +515,44 @@ fn fetch(
} else { } else {
log::info!("{}", "peers are already up to date.".green()); log::info!("{}", "peers are already up to date.".green());
} }
store.set_cidrs(cidrs); store.set_cidrs(cidrs);
store.update_peers(peers)?; store.update_peers(&peers)?;
store.write().with_str(interface.to_string())?; store.write().with_str(interface.to_string())?;
let candidates = wg::get_local_addrs()?
.into_iter()
.map(|addr| SocketAddr::from((addr, device.listen_port.unwrap_or(51820))).into())
.take(10)
.collect::<Vec<Endpoint>>();
log::info!(
"reporting {} interface address{} as NAT traversal candidates...",
candidates.len(),
if candidates.len() == 1 { "" } else { "es" }
);
log::debug!("candidates: {:?}", candidates);
match Api::new(&config.server).http_form::<_, ()>("PUT", "/user/candidates", &candidates) {
Err(ureq::Error::Status(404, _)) => {
log::warn!("your network is using an old version of innernet-server that doesn't support NAT traversal candidate reporting.")
},
Err(e) => return Err(e.into()),
_ => {},
}
log::debug!("viable ICE candidates: {:?}", candidates);
let mut nat_traverse = NatTraverse::new(interface, network.backend, &modifications)?;
loop {
if nat_traverse.is_finished() {
break;
}
log::info!(
"Attempting to establish connection with {} remaining unconnected peers...",
nat_traverse.remaining()
);
nat_traverse.step()?;
}
Ok(()) Ok(())
} }
@ -992,7 +975,9 @@ fn print_peer(peer: &PeerState, short: bool, level: usize) {
let pad = level * 2; let pad = level * 2;
let PeerState { peer, info } = peer; let PeerState { peer, info } = peer;
if short { if short {
let connected = PeerDiff::peer_recently_connected(info); let connected = info
.map(|info| !info.is_recently_connected())
.unwrap_or_default();
println_pad!( println_pad!(
pad, pad,

125
client/src/nat.rs Normal file
View File

@ -0,0 +1,125 @@
//! ICE-like NAT traversal logic.
//!
//! Doesn't follow the specific ICE protocol, but takes great inspiration from RFC 8445
//! and applies it to a protocol more specific to innernet.
use std::time::{Duration, Instant};
use anyhow::Error;
use shared::{
wg::{DeviceExt, PeerInfoExt},
Endpoint, Peer, PeerDiff,
};
use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder};
const STEP_INTERVAL: Duration = Duration::from_secs(5);
pub struct NatTraverse<'a> {
interface: &'a InterfaceName,
backend: Backend,
remaining: Vec<Peer>,
}
impl<'a> NatTraverse<'a> {
pub fn new(interface: &'a InterfaceName, backend: Backend, diffs: &[PeerDiff]) -> Result<Self, Error> {
let mut remaining: Vec<_> = diffs.iter().filter_map(|diff| diff.new).cloned().collect();
for peer in &mut remaining {
// Limit reported alternative candidates to 10.
peer.candidates.truncate(10);
// remove server-reported endpoint from elsewhere in the list if it existed.
let endpoint = peer.endpoint.clone();
peer.candidates
.retain(|addr| Some(addr) != endpoint.as_ref());
}
let mut nat_traverse = Self {
interface,
backend,
remaining,
};
nat_traverse.refresh_remaining()?;
Ok(nat_traverse)
}
pub fn is_finished(&self) -> bool {
self.remaining.is_empty()
}
pub fn remaining(&self) -> usize {
self.remaining.len()
}
/// Refreshes the current state of candidate traversal attempts, returning
/// the peers that have been exhausted of all options (not included are
/// peers that have successfully connected, or peers removed from the interface).
fn refresh_remaining(&mut self) -> Result<Vec<Peer>, Error> {
let device = Device::get(self.interface, self.backend)?;
// Remove connected and missing peers
self.remaining.retain(|peer| {
if let Some(peer_info) = device.get_peer(&peer.public_key) {
let recently_connected = peer_info.is_recently_connected();
if recently_connected {
log::debug!(
"peer {} removed from NAT traverser (connected!).",
peer.name
);
}
!recently_connected
} else {
log::debug!(
"peer {} removed from NAT traverser (no longer on interface).",
peer.name
);
false
}
});
let (exhausted, remaining): (Vec<_>, Vec<_>) = self
.remaining
.drain(..)
.partition(|peer| peer.candidates.is_empty());
self.remaining = remaining;
Ok(exhausted)
}
pub fn step(&mut self) -> Result<(), Error> {
let exhuasted = self.refresh_remaining()?;
// Reset peer endpoints that had no viable candidates back to the server-reported one, if it exists.
let reset_updates = exhuasted
.into_iter()
.filter_map(|peer| set_endpoint(&peer.public_key, peer.endpoint.as_ref()));
// Set all peers' endpoints to their next available candidate.
let candidate_updates = self.remaining.iter_mut().filter_map(|peer| {
let endpoint = peer.candidates.pop();
set_endpoint(&peer.public_key, endpoint.as_ref())
});
let updates: Vec<_> = reset_updates.chain(candidate_updates).collect();
DeviceUpdate::new()
.add_peers(&updates)
.apply(self.interface, self.backend)?;
let start = Instant::now();
while start.elapsed() < STEP_INTERVAL {
self.refresh_remaining()?;
if self.is_finished() {
log::debug!("NAT traverser is finished!");
break;
}
std::thread::sleep(Duration::from_millis(100));
}
Ok(())
}
}
/// Return a PeerConfigBuilder if an endpoint exists and resolves successfully.
fn set_endpoint(public_key: &str, endpoint: Option<&Endpoint>) -> Option<PeerConfigBuilder> {
endpoint
.and_then(|endpoint| endpoint.resolve().ok())
.map(|addr| {
PeerConfigBuilder::new(&Key::from_base64(public_key).unwrap()).set_endpoint(addr)
})
}

View File

@ -1,9 +1,9 @@
use crate::{ClientError, Error}; use crate::data_store::DataStore;
use colored::*; use colored::*;
use indoc::eprintdoc; use indoc::eprintdoc;
use log::{Level, LevelFilter}; use log::{Level, LevelFilter};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use shared::{interface_config::ServerInfo, INNERNET_PUBKEY_HEADER}; use shared::{interface_config::ServerInfo, PeerDiff, INNERNET_PUBKEY_HEADER};
use std::{io, time::Duration}; use std::{io, time::Duration};
use ureq::{Agent, AgentBuilder}; use ureq::{Agent, AgentBuilder};
@ -137,6 +137,39 @@ pub fn permissions_helptext(e: &io::Error) {
} }
} }
pub fn print_peer_diff(store: &DataStore, diff: &PeerDiff) {
let public_key = diff.public_key().to_base64();
let text = match (diff.old, diff.new) {
(None, Some(_)) => "added".green(),
(Some(_), Some(_)) => "modified".yellow(),
(Some(_), None) => "removed".red(),
_ => unreachable!("PeerDiff can't be None -> None"),
};
// Grab the peer name from either the new data, or the historical data (if the peer is removed).
let peer_hostname = match diff.new {
Some(peer) => Some(peer.name.clone()),
None => store
.peers()
.iter()
.find(|p| p.public_key == public_key)
.map(|p| p.name.clone()),
};
let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]");
log::info!(
" peer {} ({}...) was {}.",
peer_name.yellow(),
&public_key[..10].dimmed(),
text
);
for change in diff.changes() {
log::debug!(" {}", change);
}
}
pub struct Api<'a> { pub struct Api<'a> {
agent: Agent, agent: Agent,
server: &'a ServerInfo, server: &'a ServerInfo,
@ -151,7 +184,7 @@ impl<'a> Api<'a> {
Self { agent, server } Self { agent, server }
} }
pub fn http<T: DeserializeOwned>(&self, verb: &str, endpoint: &str) -> Result<T, Error> { pub fn http<T: DeserializeOwned>(&self, verb: &str, endpoint: &str) -> Result<T, ureq::Error> {
self.request::<(), _>(verb, endpoint, None) self.request::<(), _>(verb, endpoint, None)
} }
@ -160,7 +193,7 @@ impl<'a> Api<'a> {
verb: &str, verb: &str,
endpoint: &str, endpoint: &str,
form: S, form: S,
) -> Result<T, Error> { ) -> Result<T, ureq::Error> {
self.request(verb, endpoint, Some(form)) self.request(verb, endpoint, Some(form))
} }
@ -169,7 +202,7 @@ impl<'a> Api<'a> {
verb: &str, verb: &str,
endpoint: &str, endpoint: &str,
form: Option<S>, form: Option<S>,
) -> Result<T, Error> { ) -> Result<T, ureq::Error> {
let request = self let request = self
.agent .agent
.request( .request(
@ -179,7 +212,12 @@ impl<'a> Api<'a> {
.set(INNERNET_PUBKEY_HEADER, &self.server.public_key); .set(INNERNET_PUBKEY_HEADER, &self.server.public_key);
let response = if let Some(form) = form { let response = if let Some(form) = form {
request.send_json(serde_json::to_value(form)?)? request.send_json(serde_json::to_value(form).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("failed to serialize JSON request: {}", e),
)
})?)?
} else { } else {
request.call()? request.call()?
}; };
@ -190,10 +228,13 @@ impl<'a> Api<'a> {
response = "null".into(); response = "null".into();
} }
Ok(serde_json::from_str(&response).map_err(|e| { Ok(serde_json::from_str(&response).map_err(|e| {
ClientError(format!( io::Error::new(
io::ErrorKind::InvalidData,
format!(
"failed to deserialize JSON response from the server: {}, response={}", "failed to deserialize JSON response from the server: {}, response={}",
e, &response e, &response
)) ),
)
})?) })?)
} }
} }

View File

@ -36,11 +36,20 @@ pub async fn routes(
let form = form_body(req).await?; let form = form_body(req).await?;
handlers::endpoint(form, session).await handlers::endpoint(form, session).await
}, },
(&Method::PUT, Some("candidates")) => {
if !session.user_capable() {
return Err(ServerError::Unauthorized);
}
let form = form_body(req).await?;
handlers::candidates(form, session).await
},
_ => Err(ServerError::NotFound), _ => Err(ServerError::NotFound),
} }
} }
mod handlers { mod handlers {
use shared::Endpoint;
use super::*; use super::*;
/// Get the current state of the network, in the eyes of the current peer. /// Get the current state of the network, in the eyes of the current peer.
@ -115,14 +124,29 @@ mod handlers {
status_response(StatusCode::NO_CONTENT) status_response(StatusCode::NO_CONTENT)
} }
/// Redeems an invitation. An invitation includes a WireGuard keypair generated by either the server /// Report any other endpoint candidates that can be tried by peers to connect.
/// or a peer with admin rights. /// Currently limited to 10 candidates max.
/// pub async fn candidates(
/// Redemption is the process of an invitee generating their own keypair and exchanging their temporary contents: Vec<Endpoint>,
/// key with their permanent one. session: Session,
/// ) -> Result<Response<Body>, ServerError> {
/// Until this API endpoint is called, the invited peer will not show up to other peers, and once if contents.len() > 10 {
/// it is called and succeeds, it cannot be called again. return status_response(StatusCode::PAYLOAD_TOO_LARGE);
}
let conn = session.context.db.lock();
let mut selected_peer = DatabasePeer::get(&conn, session.peer.id)?;
selected_peer.update(
&conn,
PeerContents {
candidates: contents,
..selected_peer.contents.clone()
},
)?;
status_response(StatusCode::NO_CONTENT)
}
/// Force a specific endpoint to be reported by the server.
pub async fn endpoint( pub async fn endpoint(
contents: EndpointContents, contents: EndpointContents,
session: Session, session: Session,
@ -148,7 +172,7 @@ mod tests {
use super::*; use super::*;
use crate::{db::DatabaseAssociation, test}; use crate::{db::DatabaseAssociation, test};
use bytes::Buf; use bytes::Buf;
use shared::{AssociationContents, CidrContents, EndpointContents, Error}; use shared::{AssociationContents, CidrContents, Endpoint, EndpointContents, Error};
#[tokio::test] #[tokio::test]
async fn test_get_state_from_developer1() -> Result<(), Error> { async fn test_get_state_from_developer1() -> Result<(), Error> {
@ -406,4 +430,36 @@ mod tests {
assert_eq!(res.status(), StatusCode::UNAUTHORIZED); assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_candidates() -> Result<(), Error> {
let server = test::Server::new()?;
let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?;
assert_eq!(peer.candidates, vec![]);
let candidates = vec!["1.1.1.1:51820".parse::<Endpoint>().unwrap()];
assert_eq!(
server
.form_request(
test::DEVELOPER1_PEER_IP,
"PUT",
"/v1/user/candidates",
&candidates
)
.await
.status(),
StatusCode::NO_CONTENT
);
let res = server
.request(test::DEVELOPER1_PEER_IP, "GET", "/v1/user/state")
.await;
assert_eq!(res.status(), StatusCode::OK);
let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?;
assert_eq!(peer.candidates, candidates);
Ok(())
}
} }

View File

@ -8,11 +8,13 @@ pub use peer::DatabasePeer;
use rusqlite::params; use rusqlite::params;
const INVITE_EXPIRATION_VERSION: usize = 1; const INVITE_EXPIRATION_VERSION: usize = 1;
const ENDPOINT_CANDIDATES_VERSION: usize = 2;
pub const CURRENT_VERSION: usize = INVITE_EXPIRATION_VERSION; pub const CURRENT_VERSION: usize = ENDPOINT_CANDIDATES_VERSION;
pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
let old_version: usize = conn.pragma_query_value(None, "user_version", |r| r.get(0))?; let old_version: usize = conn.pragma_query_value(None, "user_version", |r| r.get(0))?;
log::debug!("user_version: {}", old_version);
if old_version < INVITE_EXPIRATION_VERSION { if old_version < INVITE_EXPIRATION_VERSION {
conn.execute( conn.execute(
@ -21,8 +23,12 @@ pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error>
)?; )?;
} }
conn.pragma_update(None, "user_version", &CURRENT_VERSION)?; if old_version < ENDPOINT_CANDIDATES_VERSION {
conn.execute("ALTER TABLE peers ADD COLUMN candidates TEXT", params![])?;
}
if old_version != CURRENT_VERSION { if old_version != CURRENT_VERSION {
conn.pragma_update(None, "user_version", &CURRENT_VERSION)?;
log::info!( log::info!(
"migrated db version from {} to {}", "migrated db version from {} to {}",
old_version, old_version,

View File

@ -2,8 +2,8 @@ use super::DatabaseCidr;
use crate::ServerError; use crate::ServerError;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use rusqlite::{params, Connection}; use rusqlite::{params, types::Type, Connection};
use shared::{Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS}; use shared::{Endpoint, Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS};
use std::{ use std::{
net::IpAddr, net::IpAddr,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
@ -22,12 +22,27 @@ pub static CREATE_TABLE_SQL: &str = "CREATE TABLE peers (
is_disabled INTEGER DEFAULT 0 NOT NULL, /* Is the peer disabled? (peers cannot be deleted) */ is_disabled INTEGER DEFAULT 0 NOT NULL, /* Is the peer disabled? (peers cannot be deleted) */
is_redeemed INTEGER DEFAULT 0 NOT NULL, /* Has the peer redeemed their invite yet? */ is_redeemed INTEGER DEFAULT 0 NOT NULL, /* Has the peer redeemed their invite yet? */
invite_expires INTEGER, /* The UNIX time that an invited peer can no longer redeem. */ invite_expires INTEGER, /* The UNIX time that an invited peer can no longer redeem. */
candidates TEXT, /* A list of additional endpoints that peers can use to connect. */
FOREIGN KEY (cidr_id) FOREIGN KEY (cidr_id)
REFERENCES cidrs (id) REFERENCES cidrs (id)
ON UPDATE RESTRICT ON UPDATE RESTRICT
ON DELETE RESTRICT ON DELETE RESTRICT
)"; )";
pub static COLUMNS: &[&str] = &[
"id",
"name",
"ip",
"cidr_id",
"public_key",
"endpoint",
"is_admin",
"is_disabled",
"is_redeemed",
"invite_expires",
"candidates",
];
lazy_static! { lazy_static! {
/// Regex to match the requirements of hostname(7), needed to have peers also be reachable hostnames. /// Regex to match the requirements of hostname(7), needed to have peers also be reachable hostnames.
/// Note that the full length also must be maximum 63 characters, which this regex does not check. /// Note that the full length also must be maximum 63 characters, which this regex does not check.
@ -71,6 +86,7 @@ impl DatabasePeer {
is_disabled, is_disabled,
is_redeemed, is_redeemed,
invite_expires, invite_expires,
candidates,
.. ..
} = &contents; } = &contents;
log::info!("creating peer {:?}", contents); log::info!("creating peer {:?}", contents);
@ -96,8 +112,13 @@ impl DatabasePeer {
.flatten() .flatten()
.map(|t| t.as_secs()); .map(|t| t.as_secs());
let candidates = serde_json::to_string(candidates)?;
conn.execute( conn.execute(
"INSERT INTO peers (name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", &format!(
"INSERT INTO peers ({}) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
COLUMNS[1..].join(", ")
),
params![ params![
&**name, &**name,
ip.to_string(), ip.to_string(),
@ -108,6 +129,7 @@ impl DatabasePeer {
is_disabled, is_disabled,
is_redeemed, is_redeemed,
invite_expires, invite_expires,
candidates,
], ],
)?; )?;
let id = conn.last_insert_rowid(); let id = conn.last_insert_rowid();
@ -135,17 +157,21 @@ impl DatabasePeer {
endpoint: contents.endpoint, endpoint: contents.endpoint,
is_admin: contents.is_admin, is_admin: contents.is_admin,
is_disabled: contents.is_disabled, is_disabled: contents.is_disabled,
candidates: contents.candidates,
..self.contents.clone() ..self.contents.clone()
}; };
let new_candidates = serde_json::to_string(&new_contents.candidates)?;
conn.execute( conn.execute(
"UPDATE peers SET "UPDATE peers SET
name = ?1, name = ?2,
endpoint = ?2, endpoint = ?3,
is_admin = ?3, is_admin = ?4,
is_disabled = ?4 is_disabled = ?5,
WHERE id = ?5", candidates = ?6
WHERE id = ?1",
params![ params![
self.id,
&*new_contents.name, &*new_contents.name,
new_contents new_contents
.endpoint .endpoint
@ -153,7 +179,7 @@ impl DatabasePeer {
.map(|endpoint| endpoint.to_string()), .map(|endpoint| endpoint.to_string()),
new_contents.is_admin, new_contents.is_admin,
new_contents.is_disabled, new_contents.is_disabled,
self.id, new_candidates,
], ],
)?; )?;
@ -198,11 +224,11 @@ impl DatabasePeer {
let name = row let name = row
.get::<_, String>(1)? .get::<_, String>(1)?
.parse() .parse()
.map_err(|_| rusqlite::Error::ExecuteReturnedResults)?; .map_err(|_| rusqlite::Error::InvalidColumnType(1, "hostname".into(), Type::Text))?;
let ip: IpAddr = row let ip: IpAddr = row
.get::<_, String>(2)? .get::<_, String>(2)?
.parse() .parse()
.map_err(|_| rusqlite::Error::ExecuteReturnedResults)?; .map_err(|_| rusqlite::Error::InvalidColumnType(2, "ip".into(), Type::Text))?;
let cidr_id = row.get(3)?; let cidr_id = row.get(3)?;
let public_key = row.get(4)?; let public_key = row.get(4)?;
let endpoint = row let endpoint = row
@ -214,6 +240,10 @@ impl DatabasePeer {
let invite_expires = row let invite_expires = row
.get::<_, Option<u64>>(9)? .get::<_, Option<u64>>(9)?
.map(|unixtime| SystemTime::UNIX_EPOCH + Duration::from_secs(unixtime)); .map(|unixtime| SystemTime::UNIX_EPOCH + Duration::from_secs(unixtime));
let candidates_str: String = row.get(10)?;
let candidates: Vec<Endpoint> = serde_json::from_str(&candidates_str).map_err(|_| {
rusqlite::Error::InvalidColumnType(10, "candidates (json)".into(), Type::Text)
})?;
let persistent_keepalive_interval = Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS); let persistent_keepalive_interval = Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS);
@ -230,6 +260,7 @@ impl DatabasePeer {
is_disabled, is_disabled,
is_redeemed, is_redeemed,
invite_expires, invite_expires,
candidates,
}, },
} }
.into()) .into())
@ -237,10 +268,7 @@ impl DatabasePeer {
pub fn get(conn: &Connection, id: i64) -> Result<Self, ServerError> { pub fn get(conn: &Connection, id: i64) -> Result<Self, ServerError> {
let result = conn.query_row( let result = conn.query_row(
"SELECT &format!("SELECT {} FROM peers WHERE id = ?1", COLUMNS.join(", ")),
id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires
FROM peers
WHERE id = ?1",
params![id], params![id],
Self::from_row, Self::from_row,
)?; )?;
@ -250,10 +278,7 @@ impl DatabasePeer {
pub fn get_from_ip(conn: &Connection, ip: IpAddr) -> Result<Self, rusqlite::Error> { pub fn get_from_ip(conn: &Connection, ip: IpAddr) -> Result<Self, rusqlite::Error> {
let result = conn.query_row( let result = conn.query_row(
"SELECT &format!("SELECT {} FROM peers WHERE ip = ?1", COLUMNS.join(", ")),
id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires
FROM peers
WHERE ip = ?1",
params![ip.to_string()], params![ip.to_string()],
Self::from_row, Self::from_row,
)?; )?;
@ -271,7 +296,7 @@ impl DatabasePeer {
// //
// NOTE that a forced association is created with the special "infra" CIDR with id 2 (1 being the root). // NOTE that a forced association is created with the special "infra" CIDR with id 2 (1 being the root).
let mut stmt = conn.prepare_cached( let mut stmt = conn.prepare_cached(
"WITH &format!("WITH
parent_of(id, parent) AS ( parent_of(id, parent) AS (
SELECT id, parent FROM cidrs WHERE id = ?1 SELECT id, parent FROM cidrs WHERE id = ?1
UNION ALL UNION ALL
@ -289,10 +314,12 @@ impl DatabasePeer {
UNION UNION
SELECT id FROM cidrs, associated_subcidrs WHERE cidrs.parent=associated_subcidrs.cidr_id SELECT id FROM cidrs, associated_subcidrs WHERE cidrs.parent=associated_subcidrs.cidr_id
) )
SELECT DISTINCT peers.id, peers.name, peers.ip, peers.cidr_id, peers.public_key, peers.endpoint, peers.is_admin, peers.is_disabled, peers.is_redeemed, peers.invite_expires SELECT DISTINCT {}
FROM peers FROM peers
JOIN associated_subcidrs ON peers.cidr_id=associated_subcidrs.cidr_id JOIN associated_subcidrs ON peers.cidr_id=associated_subcidrs.cidr_id
WHERE peers.is_disabled = 0 AND peers.is_redeemed = 1;", WHERE peers.is_disabled = 0 AND peers.is_redeemed = 1;",
COLUMNS.iter().map(|col| format!("peers.{}", col)).collect::<Vec<_>>().join(", ")
),
)?; )?;
let peers = stmt let peers = stmt
.query_map(params![self.cidr_id], Self::from_row)? .query_map(params![self.cidr_id], Self::from_row)?
@ -301,9 +328,7 @@ impl DatabasePeer {
} }
pub fn list(conn: &Connection) -> Result<Vec<Self>, ServerError> { pub fn list(conn: &Connection) -> Result<Vec<Self>, ServerError> {
let mut stmt = conn.prepare_cached( let mut stmt = conn.prepare_cached(&format!("SELECT {} FROM peers", COLUMNS.join(", ")))?;
"SELECT id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires FROM peers",
)?;
let peer_iter = stmt.query_map(params![], Self::from_row)?; let peer_iter = stmt.query_map(params![], Self::from_row)?;
Ok(peer_iter.collect::<Result<_, _>>()?) Ok(peer_iter.collect::<Result<_, _>>()?)

View File

@ -17,6 +17,7 @@ fn create_database<P: AsRef<Path>>(
conn.execute(db::association::CREATE_TABLE_SQL, params![])?; conn.execute(db::association::CREATE_TABLE_SQL, params![])?;
conn.execute(db::cidr::CREATE_TABLE_SQL, params![])?; conn.execute(db::cidr::CREATE_TABLE_SQL, params![])?;
conn.pragma_update(None, "user_version", &db::CURRENT_VERSION)?; conn.pragma_update(None, "user_version", &db::CURRENT_VERSION)?;
log::debug!("set database version to db::CURRENT_VERSION");
Ok(conn) Ok(conn)
} }
@ -89,6 +90,7 @@ fn populate_database(conn: &Connection, db_init_data: DbInitData) -> Result<(),
is_redeemed: true, is_redeemed: true,
persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS), persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS),
invite_expires: None, invite_expires: None,
candidates: vec![],
}, },
) )
.map_err(|_| anyhow!("failed to create innernet peer."))?; .map_err(|_| anyhow!("failed to create innernet peer."))?;

View File

@ -476,9 +476,11 @@ async fn serve(
network: NetworkOpt, network: NetworkOpt,
) -> Result<(), Error> { ) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(&interface))?; let config = ConfigFile::from_file(conf.config_path(&interface))?;
log::debug!("opening database connection...");
let conn = open_database_connection(&interface, conf)?; let conn = open_database_connection(&interface, conf)?;
let peers = DatabasePeer::list(&conn)?; let peers = DatabasePeer::list(&conn)?;
log::debug!("peers listed...");
let peer_configs = peers let peer_configs = peers
.iter() .iter()
.map(|peer| peer.deref().into()) .map(|peer| peer.deref().into())

View File

@ -231,6 +231,7 @@ pub fn peer_contents(
is_disabled: false, is_disabled: false,
is_redeemed: true, is_redeemed: true,
invite_expires: None, invite_expires: None,
candidates: vec![],
}) })
} }

View File

@ -29,3 +29,6 @@ netlink-sys = "0.7"
netlink-packet-core = "0.2" netlink-packet-core = "0.2"
netlink-packet-route = "0.7" netlink-packet-route = "0.7"
wgctrl-sys = { path = "../wgctrl-sys" } wgctrl-sys = { path = "../wgctrl-sys" }
[target.'cfg(target_os = "macos")'.dependencies]
nix = "0.22"

View File

@ -3,11 +3,14 @@ use netlink_packet_core::{
NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST,
}; };
use netlink_packet_route::{ use netlink_packet_route::{
address, constants::*, link, route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, address,
RouteHeader, RouteMessage, RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN, constants::*,
link::{self, nlas::State},
route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, RouteHeader, RouteMessage,
RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN,
}; };
use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr}; use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
use std::io; use std::{io, net::IpAddr};
use wgctrl::InterfaceName; use wgctrl::InterfaceName;
fn if_nametoindex(interface: &InterfaceName) -> Result<u32, io::Error> { fn if_nametoindex(interface: &InterfaceName) -> Result<u32, io::Error> {
@ -23,7 +26,7 @@ fn if_nametoindex(interface: &InterfaceName) -> Result<u32, io::Error> {
fn netlink_call( fn netlink_call(
message: RtnlMessage, message: RtnlMessage,
flags: Option<u16>, flags: Option<u16>,
) -> Result<NetlinkMessage<RtnlMessage>, io::Error> { ) -> Result<Vec<NetlinkMessage<RtnlMessage>>, io::Error> {
let mut req = NetlinkMessage::from(message); let mut req = NetlinkMessage::from(message);
req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE); req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
req.finalize(); req.finalize();
@ -32,10 +35,10 @@ fn netlink_call(
let len = req.buffer_len(); let len = req.buffer_len();
log::debug!("netlink request: {:?}", req); log::debug!("netlink request: {:?}", req);
let socket = Socket::new(NETLINK_ROUTE).unwrap(); let socket = Socket::new(NETLINK_ROUTE)?;
let kernel_addr = SocketAddr::new(0, 0); let kernel_addr = SocketAddr::new(0, 0);
socket.connect(&kernel_addr)?; socket.connect(&kernel_addr)?;
let n_sent = socket.send(&buf[..len], 0).unwrap(); let n_sent = socket.send(&buf[..len], 0)?;
if n_sent != len { if n_sent != len {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::UnexpectedEof, io::ErrorKind::UnexpectedEof,
@ -43,14 +46,31 @@ fn netlink_call(
)); ));
} }
let n_received = socket.recv(&mut buf[..], 0).unwrap(); let mut responses = vec![];
let response = NetlinkMessage::<RtnlMessage>::deserialize(&buf[..n_received]) loop {
let n_received = socket.recv(&mut buf[..], 0)?;
let mut offset = 0;
loop {
let bytes = &buf[offset..];
let response = NetlinkMessage::<RtnlMessage>::deserialize(bytes)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
responses.push(response.clone());
log::trace!("netlink response: {:?}", response); log::trace!("netlink response: {:?}", response);
if let NetlinkPayload::Error(e) = response.payload { match response.payload {
return Err(e.to_io()); // We've parsed all parts of the response and can leave the loop.
NetlinkPayload::Ack(_) | NetlinkPayload::Done => return Ok(responses),
NetlinkPayload::Error(e) => return Err(e.into()),
_ => {},
}
offset += response.header.length as usize;
if offset == n_received || response.header.length == 0 {
// We've fully parsed the datagram, but there may be further datagrams
// with additional netlink response parts.
log::debug!("breaking inner loop");
break;
}
}
} }
Ok(response)
} }
pub fn set_up(interface: &InterfaceName, mtu: u32) -> Result<(), io::Error> { pub fn set_up(interface: &InterfaceName, mtu: u32) -> Result<(), io::Error> {
@ -127,3 +147,87 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result<bool, io:
Err(e) => Err(e), Err(e) => Err(e),
} }
} }
fn get_links() -> Result<Vec<String>, io::Error> {
let link_responses = netlink_call(
RtnlMessage::GetLink(LinkMessage::default()),
Some(NLM_F_DUMP | NLM_F_REQUEST),
)?;
let links = link_responses
.into_iter()
// Filter out non-link messages
.filter_map(|response| match response {
NetlinkMessage {
payload: NetlinkPayload::InnerMessage(RtnlMessage::NewLink(link)),
..
} => Some(link),
_ => None,
})
// Filter out loopback links
.filter_map(|link| if link.header.flags & IFF_LOOPBACK == 0 {
Some(link.nlas)
} else {
None
})
// Find and filter out addresses for interfaces
.filter(|nlas| nlas.iter().any(|nla| nla == &link::nlas::Nla::OperState(State::Up)))
.filter_map(|nlas| nlas.iter().find_map(|nla| match nla {
link::nlas::Nla::IfName(name) => Some(name.clone()),
_ => None,
}))
.collect::<Vec<_>>();
Ok(links)
}
pub fn get_local_addrs() -> Result<Vec<IpAddr>, io::Error> {
let links = get_links()?;
let addr_responses = netlink_call(
RtnlMessage::GetAddress(AddressMessage::default()),
Some(NLM_F_DUMP | NLM_F_REQUEST),
)?;
let addrs = addr_responses
.into_iter()
// Filter out non-link messages
.filter_map(|response| match response {
NetlinkMessage {
payload: NetlinkPayload::InnerMessage(RtnlMessage::NewAddress(addr)),
..
} => Some(addr),
_ => None,
})
// Filter out non-global-scoped addresses
.filter_map(|link| if link.header.scope == RT_SCOPE_UNIVERSE {
Some(link.nlas)
} else {
None
})
// Only select addresses for helpful links
.filter(|nlas| nlas.iter().any(|nla| matches!(nla, address::nlas::Nla::Label(label) if links.contains(label))))
.filter_map(|nlas| nlas.iter().find_map(|nla| match nla {
address::nlas::Nla::Address(name) if name.len() == 4 => {
let mut addr = [0u8; 4];
addr.copy_from_slice(name);
Some(IpAddr::V4(addr.into()))
},
address::nlas::Nla::Address(name) if name.len() == 16 => {
let mut addr = [0u8; 16];
addr.copy_from_slice(name);
Some(IpAddr::V6(addr.into()))
},
_ => None,
}))
.collect::<Vec<_>>();
Ok(addrs)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_local_addrs() {
let addrs = get_local_addrs().unwrap();
println!("{:?}", addrs);
}
}

View File

@ -285,6 +285,7 @@ pub fn add_peer(
is_redeemed: false, is_redeemed: false,
persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS), persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS),
invite_expires: Some(SystemTime::now() + invite_expires.into()), invite_expires: Some(SystemTime::now() + invite_expires.into()),
candidates: vec![],
}; };
Ok( Ok(

View File

@ -7,7 +7,7 @@ use std::{
fmt::{self, Display, Formatter}, fmt::{self, Display, Formatter},
io, io,
net::{IpAddr, SocketAddr, ToSocketAddrs}, net::{IpAddr, SocketAddr, ToSocketAddrs},
ops::Deref, ops::{Deref, DerefMut},
path::Path, path::Path,
str::FromStr, str::FromStr,
time::{Duration, SystemTime}, time::{Duration, SystemTime},
@ -20,6 +20,8 @@ use wgctrl::{
PeerInfo, PeerInfo,
}; };
use crate::wg::PeerInfoExt;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Interface { pub struct Interface {
name: InterfaceName, name: InterfaceName,
@ -408,6 +410,8 @@ pub struct PeerContents {
pub is_disabled: bool, pub is_disabled: bool,
pub is_redeemed: bool, pub is_redeemed: bool,
pub invite_expires: Option<SystemTime>, pub invite_expires: Option<SystemTime>,
#[serde(default)]
pub candidates: Vec<Endpoint>,
} }
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
@ -426,6 +430,12 @@ impl Deref for Peer {
} }
} }
impl DerefMut for Peer {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.contents
}
}
impl Display for Peer { impl Display for Peer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{} ({})", &self.name, &self.public_key) write!(f, "{} ({})", &self.name, &self.public_key)
@ -497,19 +507,6 @@ impl<'a> PeerDiff<'a> {
} }
} }
/// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this
/// as a heuristic for "currentness" without relying on heavier things like ICMP.
pub fn peer_recently_connected(peer: &Option<&PeerInfo>) -> bool {
const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
let last_handshake = peer
.and_then(|p| p.stats.last_handshake_time)
.and_then(|t| t.elapsed().ok())
.unwrap_or_else(|| SystemTime::UNIX_EPOCH.elapsed().unwrap());
last_handshake <= REJECT_AFTER_TIME
}
pub fn public_key(&self) -> &Key { pub fn public_key(&self) -> &Key {
self.builder.public_key() self.builder.public_key()
} }
@ -570,7 +567,10 @@ impl<'a> PeerDiff<'a> {
} }
// We won't update the endpoint if there's already a stable connection. // We won't update the endpoint if there's already a stable connection.
if !Self::peer_recently_connected(&old_info) { if !old_info
.map(|info| info.is_recently_connected())
.unwrap_or_default()
{
let resolved = new.endpoint.as_ref().and_then(|e| e.resolve().ok()); let resolved = new.endpoint.as_ref().and_then(|e| e.resolve().ok());
if let Some(addr) = resolved { if let Some(addr) = resolved {
if old.is_none() || matches!(old, Some(old) if old.endpoint != resolved) { if old.is_none() || matches!(old, Some(old) if old.endpoint != resolved) {
@ -772,6 +772,7 @@ mod tests {
is_disabled: false, is_disabled: false,
is_redeemed: true, is_redeemed: true,
invite_expires: None, invite_expires: None,
candidates: vec![],
}, },
}; };
let builder = let builder =
@ -806,6 +807,7 @@ mod tests {
is_disabled: false, is_disabled: false,
is_redeemed: true, is_redeemed: true,
invite_expires: None, invite_expires: None,
candidates: vec![],
}, },
}; };
let builder = let builder =
@ -840,6 +842,7 @@ mod tests {
is_disabled: false, is_disabled: false,
is_redeemed: true, is_redeemed: true,
invite_expires: None, invite_expires: None,
candidates: vec![],
}, },
}; };
let builder = let builder =

View File

@ -1,10 +1,11 @@
use crate::{Error, IoErrorContext, NetworkOpt}; use crate::{Error, IoErrorContext, NetworkOpt, Peer, PeerDiff};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use std::{ use std::{
io, io,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
time::Duration,
}; };
use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, PeerConfigBuilder}; use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder, PeerInfo};
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
fn cmd(bin: &str, args: &[&str]) -> Result<std::process::Output, io::Error> { fn cmd(bin: &str, args: &[&str]) -> Result<std::process::Output, io::Error> {
@ -164,3 +165,97 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result<bool, io:
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub use super::netlink::add_route; pub use super::netlink::add_route;
#[cfg(target_os = "macos")]
pub fn get_local_addrs() -> Result<Vec<IpAddr>, io::Error> {
use nix::{net::if_::InterfaceFlags, sys::socket::SockAddr};
let addrs = nix::ifaddrs::getifaddrs()?
.inspect(|addr| println!("{:?}", addr))
.filter(|addr| {
addr.flags.contains(InterfaceFlags::IFF_UP)
&& !addr.flags.intersects(
InterfaceFlags::IFF_LOOPBACK
| InterfaceFlags::IFF_POINTOPOINT
| InterfaceFlags::IFF_PROMISC,
)
})
.filter_map(|addr| match addr.address {
Some(SockAddr::Inet(addr)) if addr.to_std().is_ipv4() => Some(addr.to_std().ip()),
_ => None,
})
.collect::<Vec<_>>();
Ok(addrs)
}
#[cfg(target_os = "linux")]
pub use super::netlink::get_local_addrs;
pub trait DeviceExt {
/// Diff the output of a wgctrl device with a list of server-reported peers.
fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec<PeerDiff<'a>>;
// /// Get a peer by their public key, a helper function.
fn get_peer(&self, public_key: &str) -> Option<&PeerInfo>;
}
impl DeviceExt for Device {
fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec<PeerDiff<'a>> {
let interface_public_key = self
.public_key
.as_ref()
.map(|k| k.to_base64())
.unwrap_or_default();
let existing_peers = &self.peers;
// Match existing peers (by pubkey) to new peer information from the server.
let modifications = peers.iter().filter_map(|peer| {
if peer.is_disabled || peer.public_key == interface_public_key {
None
} else {
let existing_peer = existing_peers
.iter()
.find(|p| p.config.public_key.to_base64() == peer.public_key);
PeerDiff::new(existing_peer, Some(peer)).unwrap()
}
});
// Remove any peers on the interface that aren't in the server's peer list any more.
let removals = existing_peers.iter().filter_map(|existing| {
let public_key = existing.config.public_key.to_base64();
if peers.iter().any(|p| p.public_key == public_key) {
None
} else {
PeerDiff::new(Some(existing), None).unwrap()
}
});
modifications.chain(removals).collect::<Vec<_>>()
}
fn get_peer(&self, public_key: &str) -> Option<&PeerInfo> {
Key::from_base64(public_key)
.ok()
.and_then(|key| self.peers.iter().find(|peer| peer.config.public_key == key))
}
}
pub trait PeerInfoExt {
/// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this
/// as a heuristic for "currentness" without relying on heavier things like ICMP.
fn is_recently_connected(&self) -> bool;
}
impl PeerInfoExt for PeerInfo {
fn is_recently_connected(&self) -> bool {
const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
let last_handshake = self
.stats
.last_handshake_time
.and_then(|t| t.elapsed().ok())
.unwrap_or(Duration::MAX);
last_handshake <= REJECT_AFTER_TIME
}
}