NAT traversal: ICE-esque candidate selection (#134)

This change adds the ability for peers to report additional candidate endpoints for other peers to attempt connections with outside of the endpoint reported by the coordinating server.

While not a complete solution to the full spectrum of NAT traversal issues (TURN-esque proxying is still notably missing), it allows peers within the same NAT to connect to each other via their LAN addresses, which is a win nonetheless. In the future, more advanced candidate discovery could be used to punch through additional types of NAT cone types as well.

Co-authored-by: Matěj Laitl <matej@laitl.cz>
pull/136/head
Jake McGinty 2021-09-01 18:58:46 +09:00 committed by GitHub
parent fd06b8054d
commit 8903604caa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 623 additions and 150 deletions

27
Cargo.lock generated
View File

@ -525,9 +525,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.98"
version = "0.2.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790"
checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21"
[[package]]
name = "libsqlite3-sys"
@ -569,6 +569,15 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc"
[[package]]
name = "memoffset"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9"
dependencies = [
"autocfg",
]
[[package]]
name = "mio"
version = "0.7.13"
@ -639,6 +648,19 @@ dependencies = [
"log",
]
[[package]]
name = "nix"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7555d6c7164cc913be1ce7f95cbecdabda61eb2ccd89008524af306fb7f5031"
dependencies = [
"bitflags",
"cc",
"cfg-if",
"libc",
"memoffset",
]
[[package]]
name = "nom"
version = "6.1.2"
@ -999,6 +1021,7 @@ dependencies = [
"netlink-packet-core",
"netlink-packet-route",
"netlink-sys",
"nix",
"publicip",
"regex",
"serde",

View File

@ -80,7 +80,7 @@ impl DataStore {
///
/// Note, however, that this does not prevent a compromised server from adding a new
/// peer under its control, of course.
pub fn update_peers(&mut self, current_peers: Vec<Peer>) -> Result<(), Error> {
pub fn update_peers(&mut self, current_peers: &[Peer]) -> Result<(), Error> {
let peers = match &mut self.contents {
Contents::V1 { ref mut peers, .. } => peers,
};
@ -149,6 +149,7 @@ mod tests {
is_redeemed: true,
persistent_keepalive_interval: None,
invite_expires: None,
candidates: vec![],
}
}];
static ref BASE_CIDRS: Vec<Cidr> = vec![Cidr {
@ -167,7 +168,7 @@ mod tests {
assert_eq!(0, store.peers().len());
assert_eq!(0, store.cidrs().len());
store.update_peers(BASE_PEERS.to_owned()).unwrap();
store.update_peers(&BASE_PEERS).unwrap();
store.set_cidrs(BASE_CIDRS.to_owned());
store.write().unwrap();
}
@ -189,13 +190,13 @@ mod tests {
DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap();
// Should work, since peer is unmodified.
store.update_peers(BASE_PEERS.clone()).unwrap();
store.update_peers(&BASE_PEERS).unwrap();
let mut modified = BASE_PEERS.clone();
modified[0].contents.public_key = "foo".to_string();
// Should NOT work, since peer is unmodified.
assert!(store.update_peers(modified).is_err());
assert!(store.update_peers(&modified).is_err());
}
#[test]
@ -206,7 +207,7 @@ mod tests {
DataStore::open_with_path(&dir.path().join("peer_store.json"), false).unwrap();
// Should work, since peer is unmodified.
store.update_peers(vec![]).unwrap();
store.update_peers(&[]).unwrap();
let new_peers = BASE_PEERS
.iter()
.cloned()

View File

@ -4,13 +4,17 @@ use dialoguer::{Confirm, Input};
use hostsfile::HostsBuilder;
use indoc::eprintdoc;
use shared::{
interface_config::InterfaceConfig, prompts, AddAssociationOpts, AddCidrOpts, AddPeerOpts,
Association, AssociationContents, Cidr, CidrTree, DeleteCidrOpts, EndpointContents,
InstallOpts, Interface, IoErrorContext, NetworkOpt, Peer, PeerDiff, RedeemContents,
RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR, REDEEM_TRANSITION_WAIT,
interface_config::InterfaceConfig,
prompts,
wg::{DeviceExt, PeerInfoExt},
AddAssociationOpts, AddCidrOpts, AddPeerOpts, Association, AssociationContents, Cidr, CidrTree,
DeleteCidrOpts, Endpoint, EndpointContents, InstallOpts, Interface, IoErrorContext, NetworkOpt,
Peer, RedeemContents, RenamePeerOpts, State, WrappedIoError, CLIENT_CONFIG_DIR,
REDEEM_TRANSITION_WAIT,
};
use std::{
fmt, io,
net::SocketAddr,
path::{Path, PathBuf},
thread,
time::Duration,
@ -19,9 +23,11 @@ use structopt::{clap::AppSettings, StructOpt};
use wgctrl::{Device, DeviceUpdate, InterfaceName, PeerConfigBuilder, PeerInfo};
mod data_store;
mod nat;
mod util;
use data_store::DataStore;
use nat::NatTraverse;
use shared::{wg, Error};
use util::{human_duration, human_size, Api};
@ -484,70 +490,13 @@ fn fetch(
let mut store = DataStore::open_or_create(interface)?;
let State { peers, cidrs } = Api::new(&config.server).http("GET", "/user/state")?;
let device_info = Device::get(interface, network.backend).with_str(interface.as_str_lossy())?;
let interface_public_key = device_info
.public_key
.as_ref()
.map(|k| k.to_base64())
.unwrap_or_default();
let existing_peers = &device_info.peers;
// Match existing peers (by pubkey) to new peer information from the server.
let modifications = peers.iter().filter_map(|peer| {
if peer.is_disabled || peer.public_key == interface_public_key {
None
} else {
let existing_peer = existing_peers
.iter()
.find(|p| p.config.public_key.to_base64() == peer.public_key);
PeerDiff::new(existing_peer, Some(peer)).unwrap()
}
});
// Remove any peers on the interface that aren't in the server's peer list any more.
let removals = existing_peers.iter().filter_map(|existing| {
let public_key = existing.config.public_key.to_base64();
if peers.iter().any(|p| p.public_key == public_key) {
None
} else {
PeerDiff::new(Some(existing), None).unwrap()
}
});
let device = Device::get(interface, network.backend)?;
let modifications = device.diff(&peers);
let updates = modifications
.chain(removals)
.inspect(|diff| {
let public_key = diff.public_key().to_base64();
let text = match (diff.old, diff.new) {
(None, Some(_)) => "added".green(),
(Some(_), Some(_)) => "modified".yellow(),
(Some(_), None) => "removed".red(),
_ => unreachable!("PeerDiff can't be None -> None"),
};
// Grab the peer name from either the new data, or the historical data (if the peer is removed).
let peer_hostname = match diff.new {
Some(peer) => Some(peer.name.clone()),
_ => store
.peers()
.iter()
.find(|p| p.public_key == public_key)
.map(|p| p.name.clone()),
};
let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]");
log::info!(
" peer {} ({}...) was {}.",
peer_name.yellow(),
&public_key[..10].dimmed(),
text
);
for change in diff.changes() {
log::debug!(" {}", change);
}
})
.iter()
.inspect(|diff| util::print_peer_diff(&store, diff))
.cloned()
.map(PeerConfigBuilder::from)
.collect::<Vec<_>>();
@ -566,10 +515,44 @@ fn fetch(
} else {
log::info!("{}", "peers are already up to date.".green());
}
store.set_cidrs(cidrs);
store.update_peers(peers)?;
store.update_peers(&peers)?;
store.write().with_str(interface.to_string())?;
let candidates = wg::get_local_addrs()?
.into_iter()
.map(|addr| SocketAddr::from((addr, device.listen_port.unwrap_or(51820))).into())
.take(10)
.collect::<Vec<Endpoint>>();
log::info!(
"reporting {} interface address{} as NAT traversal candidates...",
candidates.len(),
if candidates.len() == 1 { "" } else { "es" }
);
log::debug!("candidates: {:?}", candidates);
match Api::new(&config.server).http_form::<_, ()>("PUT", "/user/candidates", &candidates) {
Err(ureq::Error::Status(404, _)) => {
log::warn!("your network is using an old version of innernet-server that doesn't support NAT traversal candidate reporting.")
},
Err(e) => return Err(e.into()),
_ => {},
}
log::debug!("viable ICE candidates: {:?}", candidates);
let mut nat_traverse = NatTraverse::new(interface, network.backend, &modifications)?;
loop {
if nat_traverse.is_finished() {
break;
}
log::info!(
"Attempting to establish connection with {} remaining unconnected peers...",
nat_traverse.remaining()
);
nat_traverse.step()?;
}
Ok(())
}
@ -992,7 +975,9 @@ fn print_peer(peer: &PeerState, short: bool, level: usize) {
let pad = level * 2;
let PeerState { peer, info } = peer;
if short {
let connected = PeerDiff::peer_recently_connected(info);
let connected = info
.map(|info| !info.is_recently_connected())
.unwrap_or_default();
println_pad!(
pad,

125
client/src/nat.rs Normal file
View File

@ -0,0 +1,125 @@
//! ICE-like NAT traversal logic.
//!
//! Doesn't follow the specific ICE protocol, but takes great inspiration from RFC 8445
//! and applies it to a protocol more specific to innernet.
use std::time::{Duration, Instant};
use anyhow::Error;
use shared::{
wg::{DeviceExt, PeerInfoExt},
Endpoint, Peer, PeerDiff,
};
use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder};
const STEP_INTERVAL: Duration = Duration::from_secs(5);
pub struct NatTraverse<'a> {
interface: &'a InterfaceName,
backend: Backend,
remaining: Vec<Peer>,
}
impl<'a> NatTraverse<'a> {
pub fn new(interface: &'a InterfaceName, backend: Backend, diffs: &[PeerDiff]) -> Result<Self, Error> {
let mut remaining: Vec<_> = diffs.iter().filter_map(|diff| diff.new).cloned().collect();
for peer in &mut remaining {
// Limit reported alternative candidates to 10.
peer.candidates.truncate(10);
// remove server-reported endpoint from elsewhere in the list if it existed.
let endpoint = peer.endpoint.clone();
peer.candidates
.retain(|addr| Some(addr) != endpoint.as_ref());
}
let mut nat_traverse = Self {
interface,
backend,
remaining,
};
nat_traverse.refresh_remaining()?;
Ok(nat_traverse)
}
pub fn is_finished(&self) -> bool {
self.remaining.is_empty()
}
pub fn remaining(&self) -> usize {
self.remaining.len()
}
/// Refreshes the current state of candidate traversal attempts, returning
/// the peers that have been exhausted of all options (not included are
/// peers that have successfully connected, or peers removed from the interface).
fn refresh_remaining(&mut self) -> Result<Vec<Peer>, Error> {
let device = Device::get(self.interface, self.backend)?;
// Remove connected and missing peers
self.remaining.retain(|peer| {
if let Some(peer_info) = device.get_peer(&peer.public_key) {
let recently_connected = peer_info.is_recently_connected();
if recently_connected {
log::debug!(
"peer {} removed from NAT traverser (connected!).",
peer.name
);
}
!recently_connected
} else {
log::debug!(
"peer {} removed from NAT traverser (no longer on interface).",
peer.name
);
false
}
});
let (exhausted, remaining): (Vec<_>, Vec<_>) = self
.remaining
.drain(..)
.partition(|peer| peer.candidates.is_empty());
self.remaining = remaining;
Ok(exhausted)
}
pub fn step(&mut self) -> Result<(), Error> {
let exhuasted = self.refresh_remaining()?;
// Reset peer endpoints that had no viable candidates back to the server-reported one, if it exists.
let reset_updates = exhuasted
.into_iter()
.filter_map(|peer| set_endpoint(&peer.public_key, peer.endpoint.as_ref()));
// Set all peers' endpoints to their next available candidate.
let candidate_updates = self.remaining.iter_mut().filter_map(|peer| {
let endpoint = peer.candidates.pop();
set_endpoint(&peer.public_key, endpoint.as_ref())
});
let updates: Vec<_> = reset_updates.chain(candidate_updates).collect();
DeviceUpdate::new()
.add_peers(&updates)
.apply(self.interface, self.backend)?;
let start = Instant::now();
while start.elapsed() < STEP_INTERVAL {
self.refresh_remaining()?;
if self.is_finished() {
log::debug!("NAT traverser is finished!");
break;
}
std::thread::sleep(Duration::from_millis(100));
}
Ok(())
}
}
/// Return a PeerConfigBuilder if an endpoint exists and resolves successfully.
fn set_endpoint(public_key: &str, endpoint: Option<&Endpoint>) -> Option<PeerConfigBuilder> {
endpoint
.and_then(|endpoint| endpoint.resolve().ok())
.map(|addr| {
PeerConfigBuilder::new(&Key::from_base64(public_key).unwrap()).set_endpoint(addr)
})
}

View File

@ -1,9 +1,9 @@
use crate::{ClientError, Error};
use crate::data_store::DataStore;
use colored::*;
use indoc::eprintdoc;
use log::{Level, LevelFilter};
use serde::{de::DeserializeOwned, Serialize};
use shared::{interface_config::ServerInfo, INNERNET_PUBKEY_HEADER};
use shared::{interface_config::ServerInfo, PeerDiff, INNERNET_PUBKEY_HEADER};
use std::{io, time::Duration};
use ureq::{Agent, AgentBuilder};
@ -137,6 +137,39 @@ pub fn permissions_helptext(e: &io::Error) {
}
}
pub fn print_peer_diff(store: &DataStore, diff: &PeerDiff) {
let public_key = diff.public_key().to_base64();
let text = match (diff.old, diff.new) {
(None, Some(_)) => "added".green(),
(Some(_), Some(_)) => "modified".yellow(),
(Some(_), None) => "removed".red(),
_ => unreachable!("PeerDiff can't be None -> None"),
};
// Grab the peer name from either the new data, or the historical data (if the peer is removed).
let peer_hostname = match diff.new {
Some(peer) => Some(peer.name.clone()),
None => store
.peers()
.iter()
.find(|p| p.public_key == public_key)
.map(|p| p.name.clone()),
};
let peer_name = peer_hostname.as_deref().unwrap_or("[unknown]");
log::info!(
" peer {} ({}...) was {}.",
peer_name.yellow(),
&public_key[..10].dimmed(),
text
);
for change in diff.changes() {
log::debug!(" {}", change);
}
}
pub struct Api<'a> {
agent: Agent,
server: &'a ServerInfo,
@ -151,7 +184,7 @@ impl<'a> Api<'a> {
Self { agent, server }
}
pub fn http<T: DeserializeOwned>(&self, verb: &str, endpoint: &str) -> Result<T, Error> {
pub fn http<T: DeserializeOwned>(&self, verb: &str, endpoint: &str) -> Result<T, ureq::Error> {
self.request::<(), _>(verb, endpoint, None)
}
@ -160,7 +193,7 @@ impl<'a> Api<'a> {
verb: &str,
endpoint: &str,
form: S,
) -> Result<T, Error> {
) -> Result<T, ureq::Error> {
self.request(verb, endpoint, Some(form))
}
@ -169,7 +202,7 @@ impl<'a> Api<'a> {
verb: &str,
endpoint: &str,
form: Option<S>,
) -> Result<T, Error> {
) -> Result<T, ureq::Error> {
let request = self
.agent
.request(
@ -179,7 +212,12 @@ impl<'a> Api<'a> {
.set(INNERNET_PUBKEY_HEADER, &self.server.public_key);
let response = if let Some(form) = form {
request.send_json(serde_json::to_value(form)?)?
request.send_json(serde_json::to_value(form).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("failed to serialize JSON request: {}", e),
)
})?)?
} else {
request.call()?
};
@ -190,10 +228,13 @@ impl<'a> Api<'a> {
response = "null".into();
}
Ok(serde_json::from_str(&response).map_err(|e| {
ClientError(format!(
"failed to deserialize JSON response from the server: {}, response={}",
e, &response
))
io::Error::new(
io::ErrorKind::InvalidData,
format!(
"failed to deserialize JSON response from the server: {}, response={}",
e, &response
),
)
})?)
}
}

View File

@ -36,11 +36,20 @@ pub async fn routes(
let form = form_body(req).await?;
handlers::endpoint(form, session).await
},
(&Method::PUT, Some("candidates")) => {
if !session.user_capable() {
return Err(ServerError::Unauthorized);
}
let form = form_body(req).await?;
handlers::candidates(form, session).await
},
_ => Err(ServerError::NotFound),
}
}
mod handlers {
use shared::Endpoint;
use super::*;
/// Get the current state of the network, in the eyes of the current peer.
@ -115,14 +124,29 @@ mod handlers {
status_response(StatusCode::NO_CONTENT)
}
/// Redeems an invitation. An invitation includes a WireGuard keypair generated by either the server
/// or a peer with admin rights.
///
/// Redemption is the process of an invitee generating their own keypair and exchanging their temporary
/// key with their permanent one.
///
/// Until this API endpoint is called, the invited peer will not show up to other peers, and once
/// it is called and succeeds, it cannot be called again.
/// Report any other endpoint candidates that can be tried by peers to connect.
/// Currently limited to 10 candidates max.
pub async fn candidates(
contents: Vec<Endpoint>,
session: Session,
) -> Result<Response<Body>, ServerError> {
if contents.len() > 10 {
return status_response(StatusCode::PAYLOAD_TOO_LARGE);
}
let conn = session.context.db.lock();
let mut selected_peer = DatabasePeer::get(&conn, session.peer.id)?;
selected_peer.update(
&conn,
PeerContents {
candidates: contents,
..selected_peer.contents.clone()
},
)?;
status_response(StatusCode::NO_CONTENT)
}
/// Force a specific endpoint to be reported by the server.
pub async fn endpoint(
contents: EndpointContents,
session: Session,
@ -148,7 +172,7 @@ mod tests {
use super::*;
use crate::{db::DatabaseAssociation, test};
use bytes::Buf;
use shared::{AssociationContents, CidrContents, EndpointContents, Error};
use shared::{AssociationContents, CidrContents, Endpoint, EndpointContents, Error};
#[tokio::test]
async fn test_get_state_from_developer1() -> Result<(), Error> {
@ -406,4 +430,36 @@ mod tests {
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
Ok(())
}
#[tokio::test]
async fn test_candidates() -> Result<(), Error> {
let server = test::Server::new()?;
let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?;
assert_eq!(peer.candidates, vec![]);
let candidates = vec!["1.1.1.1:51820".parse::<Endpoint>().unwrap()];
assert_eq!(
server
.form_request(
test::DEVELOPER1_PEER_IP,
"PUT",
"/v1/user/candidates",
&candidates
)
.await
.status(),
StatusCode::NO_CONTENT
);
let res = server
.request(test::DEVELOPER1_PEER_IP, "GET", "/v1/user/state")
.await;
assert_eq!(res.status(), StatusCode::OK);
let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?;
assert_eq!(peer.candidates, candidates);
Ok(())
}
}

View File

@ -8,11 +8,13 @@ pub use peer::DatabasePeer;
use rusqlite::params;
const INVITE_EXPIRATION_VERSION: usize = 1;
const ENDPOINT_CANDIDATES_VERSION: usize = 2;
pub const CURRENT_VERSION: usize = INVITE_EXPIRATION_VERSION;
pub const CURRENT_VERSION: usize = ENDPOINT_CANDIDATES_VERSION;
pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
let old_version: usize = conn.pragma_query_value(None, "user_version", |r| r.get(0))?;
log::debug!("user_version: {}", old_version);
if old_version < INVITE_EXPIRATION_VERSION {
conn.execute(
@ -21,8 +23,12 @@ pub fn auto_migrate(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error>
)?;
}
conn.pragma_update(None, "user_version", &CURRENT_VERSION)?;
if old_version < ENDPOINT_CANDIDATES_VERSION {
conn.execute("ALTER TABLE peers ADD COLUMN candidates TEXT", params![])?;
}
if old_version != CURRENT_VERSION {
conn.pragma_update(None, "user_version", &CURRENT_VERSION)?;
log::info!(
"migrated db version from {} to {}",
old_version,

View File

@ -2,8 +2,8 @@ use super::DatabaseCidr;
use crate::ServerError;
use lazy_static::lazy_static;
use regex::Regex;
use rusqlite::{params, Connection};
use shared::{Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS};
use rusqlite::{params, types::Type, Connection};
use shared::{Endpoint, Peer, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS};
use std::{
net::IpAddr,
ops::{Deref, DerefMut},
@ -22,12 +22,27 @@ pub static CREATE_TABLE_SQL: &str = "CREATE TABLE peers (
is_disabled INTEGER DEFAULT 0 NOT NULL, /* Is the peer disabled? (peers cannot be deleted) */
is_redeemed INTEGER DEFAULT 0 NOT NULL, /* Has the peer redeemed their invite yet? */
invite_expires INTEGER, /* The UNIX time that an invited peer can no longer redeem. */
candidates TEXT, /* A list of additional endpoints that peers can use to connect. */
FOREIGN KEY (cidr_id)
REFERENCES cidrs (id)
ON UPDATE RESTRICT
ON DELETE RESTRICT
)";
pub static COLUMNS: &[&str] = &[
"id",
"name",
"ip",
"cidr_id",
"public_key",
"endpoint",
"is_admin",
"is_disabled",
"is_redeemed",
"invite_expires",
"candidates",
];
lazy_static! {
/// Regex to match the requirements of hostname(7), needed to have peers also be reachable hostnames.
/// Note that the full length also must be maximum 63 characters, which this regex does not check.
@ -71,6 +86,7 @@ impl DatabasePeer {
is_disabled,
is_redeemed,
invite_expires,
candidates,
..
} = &contents;
log::info!("creating peer {:?}", contents);
@ -96,8 +112,13 @@ impl DatabasePeer {
.flatten()
.map(|t| t.as_secs());
let candidates = serde_json::to_string(candidates)?;
conn.execute(
"INSERT INTO peers (name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
&format!(
"INSERT INTO peers ({}) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
COLUMNS[1..].join(", ")
),
params![
&**name,
ip.to_string(),
@ -108,6 +129,7 @@ impl DatabasePeer {
is_disabled,
is_redeemed,
invite_expires,
candidates,
],
)?;
let id = conn.last_insert_rowid();
@ -135,17 +157,21 @@ impl DatabasePeer {
endpoint: contents.endpoint,
is_admin: contents.is_admin,
is_disabled: contents.is_disabled,
candidates: contents.candidates,
..self.contents.clone()
};
let new_candidates = serde_json::to_string(&new_contents.candidates)?;
conn.execute(
"UPDATE peers SET
name = ?1,
endpoint = ?2,
is_admin = ?3,
is_disabled = ?4
WHERE id = ?5",
name = ?2,
endpoint = ?3,
is_admin = ?4,
is_disabled = ?5,
candidates = ?6
WHERE id = ?1",
params![
self.id,
&*new_contents.name,
new_contents
.endpoint
@ -153,7 +179,7 @@ impl DatabasePeer {
.map(|endpoint| endpoint.to_string()),
new_contents.is_admin,
new_contents.is_disabled,
self.id,
new_candidates,
],
)?;
@ -198,11 +224,11 @@ impl DatabasePeer {
let name = row
.get::<_, String>(1)?
.parse()
.map_err(|_| rusqlite::Error::ExecuteReturnedResults)?;
.map_err(|_| rusqlite::Error::InvalidColumnType(1, "hostname".into(), Type::Text))?;
let ip: IpAddr = row
.get::<_, String>(2)?
.parse()
.map_err(|_| rusqlite::Error::ExecuteReturnedResults)?;
.map_err(|_| rusqlite::Error::InvalidColumnType(2, "ip".into(), Type::Text))?;
let cidr_id = row.get(3)?;
let public_key = row.get(4)?;
let endpoint = row
@ -214,6 +240,10 @@ impl DatabasePeer {
let invite_expires = row
.get::<_, Option<u64>>(9)?
.map(|unixtime| SystemTime::UNIX_EPOCH + Duration::from_secs(unixtime));
let candidates_str: String = row.get(10)?;
let candidates: Vec<Endpoint> = serde_json::from_str(&candidates_str).map_err(|_| {
rusqlite::Error::InvalidColumnType(10, "candidates (json)".into(), Type::Text)
})?;
let persistent_keepalive_interval = Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS);
@ -230,6 +260,7 @@ impl DatabasePeer {
is_disabled,
is_redeemed,
invite_expires,
candidates,
},
}
.into())
@ -237,10 +268,7 @@ impl DatabasePeer {
pub fn get(conn: &Connection, id: i64) -> Result<Self, ServerError> {
let result = conn.query_row(
"SELECT
id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires
FROM peers
WHERE id = ?1",
&format!("SELECT {} FROM peers WHERE id = ?1", COLUMNS.join(", ")),
params![id],
Self::from_row,
)?;
@ -250,10 +278,7 @@ impl DatabasePeer {
pub fn get_from_ip(conn: &Connection, ip: IpAddr) -> Result<Self, rusqlite::Error> {
let result = conn.query_row(
"SELECT
id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires
FROM peers
WHERE ip = ?1",
&format!("SELECT {} FROM peers WHERE ip = ?1", COLUMNS.join(", ")),
params![ip.to_string()],
Self::from_row,
)?;
@ -271,7 +296,7 @@ impl DatabasePeer {
//
// NOTE that a forced association is created with the special "infra" CIDR with id 2 (1 being the root).
let mut stmt = conn.prepare_cached(
"WITH
&format!("WITH
parent_of(id, parent) AS (
SELECT id, parent FROM cidrs WHERE id = ?1
UNION ALL
@ -289,10 +314,12 @@ impl DatabasePeer {
UNION
SELECT id FROM cidrs, associated_subcidrs WHERE cidrs.parent=associated_subcidrs.cidr_id
)
SELECT DISTINCT peers.id, peers.name, peers.ip, peers.cidr_id, peers.public_key, peers.endpoint, peers.is_admin, peers.is_disabled, peers.is_redeemed, peers.invite_expires
SELECT DISTINCT {}
FROM peers
JOIN associated_subcidrs ON peers.cidr_id=associated_subcidrs.cidr_id
WHERE peers.is_disabled = 0 AND peers.is_redeemed = 1;",
COLUMNS.iter().map(|col| format!("peers.{}", col)).collect::<Vec<_>>().join(", ")
),
)?;
let peers = stmt
.query_map(params![self.cidr_id], Self::from_row)?
@ -301,9 +328,7 @@ impl DatabasePeer {
}
pub fn list(conn: &Connection) -> Result<Vec<Self>, ServerError> {
let mut stmt = conn.prepare_cached(
"SELECT id, name, ip, cidr_id, public_key, endpoint, is_admin, is_disabled, is_redeemed, invite_expires FROM peers",
)?;
let mut stmt = conn.prepare_cached(&format!("SELECT {} FROM peers", COLUMNS.join(", ")))?;
let peer_iter = stmt.query_map(params![], Self::from_row)?;
Ok(peer_iter.collect::<Result<_, _>>()?)

View File

@ -17,6 +17,7 @@ fn create_database<P: AsRef<Path>>(
conn.execute(db::association::CREATE_TABLE_SQL, params![])?;
conn.execute(db::cidr::CREATE_TABLE_SQL, params![])?;
conn.pragma_update(None, "user_version", &db::CURRENT_VERSION)?;
log::debug!("set database version to db::CURRENT_VERSION");
Ok(conn)
}
@ -89,6 +90,7 @@ fn populate_database(conn: &Connection, db_init_data: DbInitData) -> Result<(),
is_redeemed: true,
persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS),
invite_expires: None,
candidates: vec![],
},
)
.map_err(|_| anyhow!("failed to create innernet peer."))?;

View File

@ -476,9 +476,11 @@ async fn serve(
network: NetworkOpt,
) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(&interface))?;
log::debug!("opening database connection...");
let conn = open_database_connection(&interface, conf)?;
let peers = DatabasePeer::list(&conn)?;
log::debug!("peers listed...");
let peer_configs = peers
.iter()
.map(|peer| peer.deref().into())

View File

@ -231,6 +231,7 @@ pub fn peer_contents(
is_disabled: false,
is_redeemed: true,
invite_expires: None,
candidates: vec![],
})
}

View File

@ -29,3 +29,6 @@ netlink-sys = "0.7"
netlink-packet-core = "0.2"
netlink-packet-route = "0.7"
wgctrl-sys = { path = "../wgctrl-sys" }
[target.'cfg(target_os = "macos")'.dependencies]
nix = "0.22"

View File

@ -3,11 +3,14 @@ use netlink_packet_core::{
NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST,
};
use netlink_packet_route::{
address, constants::*, link, route, AddressHeader, AddressMessage, LinkHeader, LinkMessage,
RouteHeader, RouteMessage, RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN,
address,
constants::*,
link::{self, nlas::State},
route, AddressHeader, AddressMessage, LinkHeader, LinkMessage, RouteHeader, RouteMessage,
RtnlMessage, RTN_UNICAST, RT_SCOPE_LINK, RT_TABLE_MAIN,
};
use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
use std::io;
use std::{io, net::IpAddr};
use wgctrl::InterfaceName;
fn if_nametoindex(interface: &InterfaceName) -> Result<u32, io::Error> {
@ -23,7 +26,7 @@ fn if_nametoindex(interface: &InterfaceName) -> Result<u32, io::Error> {
fn netlink_call(
message: RtnlMessage,
flags: Option<u16>,
) -> Result<NetlinkMessage<RtnlMessage>, io::Error> {
) -> Result<Vec<NetlinkMessage<RtnlMessage>>, io::Error> {
let mut req = NetlinkMessage::from(message);
req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
req.finalize();
@ -32,10 +35,10 @@ fn netlink_call(
let len = req.buffer_len();
log::debug!("netlink request: {:?}", req);
let socket = Socket::new(NETLINK_ROUTE).unwrap();
let socket = Socket::new(NETLINK_ROUTE)?;
let kernel_addr = SocketAddr::new(0, 0);
socket.connect(&kernel_addr)?;
let n_sent = socket.send(&buf[..len], 0).unwrap();
let n_sent = socket.send(&buf[..len], 0)?;
if n_sent != len {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
@ -43,14 +46,31 @@ fn netlink_call(
));
}
let n_received = socket.recv(&mut buf[..], 0).unwrap();
let response = NetlinkMessage::<RtnlMessage>::deserialize(&buf[..n_received])
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
log::trace!("netlink response: {:?}", response);
if let NetlinkPayload::Error(e) = response.payload {
return Err(e.to_io());
let mut responses = vec![];
loop {
let n_received = socket.recv(&mut buf[..], 0)?;
let mut offset = 0;
loop {
let bytes = &buf[offset..];
let response = NetlinkMessage::<RtnlMessage>::deserialize(bytes)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
responses.push(response.clone());
log::trace!("netlink response: {:?}", response);
match response.payload {
// We've parsed all parts of the response and can leave the loop.
NetlinkPayload::Ack(_) | NetlinkPayload::Done => return Ok(responses),
NetlinkPayload::Error(e) => return Err(e.into()),
_ => {},
}
offset += response.header.length as usize;
if offset == n_received || response.header.length == 0 {
// We've fully parsed the datagram, but there may be further datagrams
// with additional netlink response parts.
log::debug!("breaking inner loop");
break;
}
}
}
Ok(response)
}
pub fn set_up(interface: &InterfaceName, mtu: u32) -> Result<(), io::Error> {
@ -127,3 +147,87 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result<bool, io:
Err(e) => Err(e),
}
}
fn get_links() -> Result<Vec<String>, io::Error> {
let link_responses = netlink_call(
RtnlMessage::GetLink(LinkMessage::default()),
Some(NLM_F_DUMP | NLM_F_REQUEST),
)?;
let links = link_responses
.into_iter()
// Filter out non-link messages
.filter_map(|response| match response {
NetlinkMessage {
payload: NetlinkPayload::InnerMessage(RtnlMessage::NewLink(link)),
..
} => Some(link),
_ => None,
})
// Filter out loopback links
.filter_map(|link| if link.header.flags & IFF_LOOPBACK == 0 {
Some(link.nlas)
} else {
None
})
// Find and filter out addresses for interfaces
.filter(|nlas| nlas.iter().any(|nla| nla == &link::nlas::Nla::OperState(State::Up)))
.filter_map(|nlas| nlas.iter().find_map(|nla| match nla {
link::nlas::Nla::IfName(name) => Some(name.clone()),
_ => None,
}))
.collect::<Vec<_>>();
Ok(links)
}
pub fn get_local_addrs() -> Result<Vec<IpAddr>, io::Error> {
let links = get_links()?;
let addr_responses = netlink_call(
RtnlMessage::GetAddress(AddressMessage::default()),
Some(NLM_F_DUMP | NLM_F_REQUEST),
)?;
let addrs = addr_responses
.into_iter()
// Filter out non-link messages
.filter_map(|response| match response {
NetlinkMessage {
payload: NetlinkPayload::InnerMessage(RtnlMessage::NewAddress(addr)),
..
} => Some(addr),
_ => None,
})
// Filter out non-global-scoped addresses
.filter_map(|link| if link.header.scope == RT_SCOPE_UNIVERSE {
Some(link.nlas)
} else {
None
})
// Only select addresses for helpful links
.filter(|nlas| nlas.iter().any(|nla| matches!(nla, address::nlas::Nla::Label(label) if links.contains(label))))
.filter_map(|nlas| nlas.iter().find_map(|nla| match nla {
address::nlas::Nla::Address(name) if name.len() == 4 => {
let mut addr = [0u8; 4];
addr.copy_from_slice(name);
Some(IpAddr::V4(addr.into()))
},
address::nlas::Nla::Address(name) if name.len() == 16 => {
let mut addr = [0u8; 16];
addr.copy_from_slice(name);
Some(IpAddr::V6(addr.into()))
},
_ => None,
}))
.collect::<Vec<_>>();
Ok(addrs)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_local_addrs() {
let addrs = get_local_addrs().unwrap();
println!("{:?}", addrs);
}
}

View File

@ -285,6 +285,7 @@ pub fn add_peer(
is_redeemed: false,
persistent_keepalive_interval: Some(PERSISTENT_KEEPALIVE_INTERVAL_SECS),
invite_expires: Some(SystemTime::now() + invite_expires.into()),
candidates: vec![],
};
Ok(

View File

@ -7,7 +7,7 @@ use std::{
fmt::{self, Display, Formatter},
io,
net::{IpAddr, SocketAddr, ToSocketAddrs},
ops::Deref,
ops::{Deref, DerefMut},
path::Path,
str::FromStr,
time::{Duration, SystemTime},
@ -20,6 +20,8 @@ use wgctrl::{
PeerInfo,
};
use crate::wg::PeerInfoExt;
#[derive(Debug, Clone)]
pub struct Interface {
name: InterfaceName,
@ -408,6 +410,8 @@ pub struct PeerContents {
pub is_disabled: bool,
pub is_redeemed: bool,
pub invite_expires: Option<SystemTime>,
#[serde(default)]
pub candidates: Vec<Endpoint>,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
@ -426,6 +430,12 @@ impl Deref for Peer {
}
}
impl DerefMut for Peer {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.contents
}
}
impl Display for Peer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{} ({})", &self.name, &self.public_key)
@ -497,19 +507,6 @@ impl<'a> PeerDiff<'a> {
}
}
/// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this
/// as a heuristic for "currentness" without relying on heavier things like ICMP.
pub fn peer_recently_connected(peer: &Option<&PeerInfo>) -> bool {
const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
let last_handshake = peer
.and_then(|p| p.stats.last_handshake_time)
.and_then(|t| t.elapsed().ok())
.unwrap_or_else(|| SystemTime::UNIX_EPOCH.elapsed().unwrap());
last_handshake <= REJECT_AFTER_TIME
}
pub fn public_key(&self) -> &Key {
self.builder.public_key()
}
@ -570,7 +567,10 @@ impl<'a> PeerDiff<'a> {
}
// We won't update the endpoint if there's already a stable connection.
if !Self::peer_recently_connected(&old_info) {
if !old_info
.map(|info| info.is_recently_connected())
.unwrap_or_default()
{
let resolved = new.endpoint.as_ref().and_then(|e| e.resolve().ok());
if let Some(addr) = resolved {
if old.is_none() || matches!(old, Some(old) if old.endpoint != resolved) {
@ -772,6 +772,7 @@ mod tests {
is_disabled: false,
is_redeemed: true,
invite_expires: None,
candidates: vec![],
},
};
let builder =
@ -806,6 +807,7 @@ mod tests {
is_disabled: false,
is_redeemed: true,
invite_expires: None,
candidates: vec![],
},
};
let builder =
@ -840,6 +842,7 @@ mod tests {
is_disabled: false,
is_redeemed: true,
invite_expires: None,
candidates: vec![],
},
};
let builder =

View File

@ -1,10 +1,11 @@
use crate::{Error, IoErrorContext, NetworkOpt};
use crate::{Error, IoErrorContext, NetworkOpt, Peer, PeerDiff};
use ipnetwork::IpNetwork;
use std::{
io,
net::{IpAddr, SocketAddr},
time::Duration,
};
use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, PeerConfigBuilder};
use wgctrl::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerConfigBuilder, PeerInfo};
#[cfg(target_os = "macos")]
fn cmd(bin: &str, args: &[&str]) -> Result<std::process::Output, io::Error> {
@ -164,3 +165,97 @@ pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result<bool, io:
#[cfg(target_os = "linux")]
pub use super::netlink::add_route;
#[cfg(target_os = "macos")]
pub fn get_local_addrs() -> Result<Vec<IpAddr>, io::Error> {
use nix::{net::if_::InterfaceFlags, sys::socket::SockAddr};
let addrs = nix::ifaddrs::getifaddrs()?
.inspect(|addr| println!("{:?}", addr))
.filter(|addr| {
addr.flags.contains(InterfaceFlags::IFF_UP)
&& !addr.flags.intersects(
InterfaceFlags::IFF_LOOPBACK
| InterfaceFlags::IFF_POINTOPOINT
| InterfaceFlags::IFF_PROMISC,
)
})
.filter_map(|addr| match addr.address {
Some(SockAddr::Inet(addr)) if addr.to_std().is_ipv4() => Some(addr.to_std().ip()),
_ => None,
})
.collect::<Vec<_>>();
Ok(addrs)
}
#[cfg(target_os = "linux")]
pub use super::netlink::get_local_addrs;
pub trait DeviceExt {
/// Diff the output of a wgctrl device with a list of server-reported peers.
fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec<PeerDiff<'a>>;
// /// Get a peer by their public key, a helper function.
fn get_peer(&self, public_key: &str) -> Option<&PeerInfo>;
}
impl DeviceExt for Device {
fn diff<'a>(&'a self, peers: &'a [Peer]) -> Vec<PeerDiff<'a>> {
let interface_public_key = self
.public_key
.as_ref()
.map(|k| k.to_base64())
.unwrap_or_default();
let existing_peers = &self.peers;
// Match existing peers (by pubkey) to new peer information from the server.
let modifications = peers.iter().filter_map(|peer| {
if peer.is_disabled || peer.public_key == interface_public_key {
None
} else {
let existing_peer = existing_peers
.iter()
.find(|p| p.config.public_key.to_base64() == peer.public_key);
PeerDiff::new(existing_peer, Some(peer)).unwrap()
}
});
// Remove any peers on the interface that aren't in the server's peer list any more.
let removals = existing_peers.iter().filter_map(|existing| {
let public_key = existing.config.public_key.to_base64();
if peers.iter().any(|p| p.public_key == public_key) {
None
} else {
PeerDiff::new(Some(existing), None).unwrap()
}
});
modifications.chain(removals).collect::<Vec<_>>()
}
fn get_peer(&self, public_key: &str) -> Option<&PeerInfo> {
Key::from_base64(public_key)
.ok()
.and_then(|key| self.peers.iter().find(|peer| peer.config.public_key == key))
}
}
pub trait PeerInfoExt {
/// WireGuard rejects any communication after REJECT_AFTER_TIME, so we can use this
/// as a heuristic for "currentness" without relying on heavier things like ICMP.
fn is_recently_connected(&self) -> bool;
}
impl PeerInfoExt for PeerInfo {
fn is_recently_connected(&self) -> bool {
const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
let last_handshake = self
.stats
.last_handshake_time
.and_then(|t| t.elapsed().ok())
.unwrap_or(Duration::MAX);
last_handshake <= REJECT_AFTER_TIME
}
}