diff --git a/Cargo.lock b/Cargo.lock index c40426b..ea51143 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1138,6 +1138,7 @@ dependencies = [ "tokio", "toml", "ureq", + "url", "warp", "wgctrl", ] @@ -1169,6 +1170,7 @@ dependencies = [ "structopt", "toml", "ureq", + "url", "wgctrl", ] diff --git a/client/src/main.rs b/client/src/main.rs index 1ddc74b..15929bf 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -283,6 +283,7 @@ fn redeem_invite( target_conf: PathBuf, ) -> Result<(), Error> { println!("{} bringing up the interface.", "[*]".dimmed()); + let resolved_endpoint = config.server.external_endpoint.resolve()?; wg::up( &iface, &config.interface.private_key, @@ -291,7 +292,7 @@ fn redeem_invite( Some(( &config.server.public_key, config.server.internal_endpoint.ip(), - config.server.external_endpoint, + resolved_endpoint, )), )?; @@ -369,6 +370,7 @@ fn fetch( } println!("{} bringing up the interface.", "[*]".dimmed()); + let resolved_endpoint = config.server.external_endpoint.resolve()?; wg::up( interface, &config.interface.private_key, @@ -377,7 +379,7 @@ fn fetch( Some(( &config.server.public_key, config.server.internal_endpoint.ip(), - config.server.external_endpoint, + resolved_endpoint, )), )? } @@ -821,7 +823,7 @@ fn print_peer(our_peer: &Peer, peer: &PeerInfo, short: bool) -> Result<(), Error &our_peer.public_key[..10].yellow() ); println!(" {}: {}", "ip".bold(), our_peer.ip); - if let Some(endpoint) = our_peer.endpoint { + if let Some(ref endpoint) = our_peer.endpoint { println!(" {}: {}", "endpoint".bold(), endpoint); } if let Some(last_handshake) = peer.stats.last_handshake_time { diff --git a/server/Cargo.toml b/server/Cargo.toml index fb261d7..5b45c42 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -34,6 +34,7 @@ subtle = "2" structopt = "0.3" thiserror = "1" ureq = { version = "2", default-features = false } +url = "2" tokio = { version = "1", features = ["macros", "rt-multi-thread"] } toml = "0.5" warp = { git = "https://github.com/tonarino/warp", default-features = false } # pending https://github.com/seanmonstar/warp/issues/830 diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index 144d68c..41b17c8 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -11,7 +11,7 @@ pub fn inject_endpoints(session: &Session, peers: &mut Vec) { for mut peer in peers { if peer.contents.endpoint.is_none() { if let Some(endpoint) = session.context.endpoints.get(&peer.public_key) { - peer.contents.endpoint = Some(endpoint.to_owned()); + peer.contents.endpoint = Some(endpoint.to_owned().into()); } } } diff --git a/server/src/api/user.rs b/server/src/api/user.rs index b7c5929..dd296f2 100644 --- a/server/src/api/user.rs +++ b/server/src/api/user.rs @@ -200,7 +200,7 @@ mod tests { .put_request_from_ip(test::DEVELOPER1_PEER_IP) .path("/v1/user/endpoint") .body(serde_json::to_string(&EndpointContents::Set( - "1.1.1.1:51820".parse()? + "1.1.1.1:51820".parse().unwrap() ))?) .reply(&filter) .await diff --git a/server/src/db/peer.rs b/server/src/db/peer.rs index 1d695ca..517282c 100644 --- a/server/src/db/peer.rs +++ b/server/src/db/peer.rs @@ -95,7 +95,7 @@ impl DatabasePeer { ip.to_string(), cidr_id, &public_key, - endpoint.map(|endpoint| endpoint.to_string()), + endpoint.as_ref().map(|endpoint| endpoint.to_string()), is_admin, is_disabled, is_redeemed, @@ -138,7 +138,10 @@ impl DatabasePeer { WHERE id = ?5", params![ new_contents.name, - new_contents.endpoint.map(|endpoint| endpoint.to_string()), + new_contents + .endpoint + .as_ref() + .map(|endpoint| endpoint.to_string()), new_contents.is_admin, new_contents.is_disabled, self.id, diff --git a/server/src/initialize.rs b/server/src/initialize.rs index 24c6b80..8f71797 100644 --- a/server/src/initialize.rs +++ b/server/src/initialize.rs @@ -5,7 +5,7 @@ use indoc::printdoc; use rusqlite::{params, Connection}; use shared::{ prompts::{self, hostname_validator}, - CidrContents, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS, + CidrContents, Endpoint, PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS, }; use wgctrl::KeyPair; @@ -32,7 +32,7 @@ pub struct InitializeOpts { /// This server's external endpoint (ex: 100.100.100.100:51820) #[structopt(long, conflicts_with = "auto-external-endpoint")] - pub external_endpoint: Option, + pub external_endpoint: Option, /// Auto-resolve external endpoint #[structopt(long = "auto-external-endpoint")] @@ -49,7 +49,7 @@ struct DbInitData { server_cidr: IpNetwork, our_ip: IpAddr, public_key_base64: String, - endpoint: SocketAddr, + endpoint: Endpoint, } fn populate_database(conn: &Connection, db_init_data: DbInitData) -> Result<(), Error> { @@ -126,7 +126,7 @@ pub fn init_wizard(conf: &ServerConfig, opts: InitializeOpts) -> Result<(), Erro // This probably won't error because of the `hostname_validator` regex. let name = name.parse()?; - let endpoint: SocketAddr = if let Some(endpoint) = opts.external_endpoint { + let endpoint: Endpoint = if let Some(endpoint) = opts.external_endpoint { endpoint.clone() } else { let external_ip: Option = ureq::get("http://4.icanhazip.com") @@ -139,7 +139,7 @@ pub fn init_wizard(conf: &ServerConfig, opts: InitializeOpts) -> Result<(), Erro if opts.auto_external_endpoint { let ip = external_ip.ok_or("couldn't get external IP")?; - (ip, 51820).into() + SocketAddr::new(ip, 51820).into() } else { prompts::ask_endpoint(external_ip)? } diff --git a/server/src/test.rs b/server/src/test.rs index 62430dd..88948e4 100644 --- a/server/src/test.rs +++ b/server/src/test.rs @@ -69,7 +69,7 @@ impl Server { let opts = InitializeOpts { network_name: Some(interface.clone()), network_cidr: Some(ROOT_CIDR.parse()?), - external_endpoint: Some("155.155.155.155:54321".parse()?), + external_endpoint: Some("155.155.155.155:54321".parse().unwrap()), listen_port: Some(54321), auto_external_endpoint: false, }; diff --git a/shared/Cargo.toml b/shared/Cargo.toml index 1a7a5a4..0d943cd 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -17,4 +17,5 @@ serde = { version = "1", features = ["derive"] } structopt = "0.3" toml = "0.5" ureq = { version = "2", default-features = false } +url = "2" wgctrl = { path = "../wgctrl-rs" } diff --git a/shared/src/interface_config.rs b/shared/src/interface_config.rs index dd5b23c..10febaa 100644 --- a/shared/src/interface_config.rs +++ b/shared/src/interface_config.rs @@ -1,4 +1,4 @@ -use crate::{ensure_dirs_exist, Error, IoErrorContext, CLIENT_CONFIG_PATH}; +use crate::{ensure_dirs_exist, Endpoint, Error, IoErrorContext, CLIENT_CONFIG_PATH}; use colored::*; use indoc::writedoc; use ipnetwork::IpNetwork; @@ -46,7 +46,7 @@ pub struct ServerInfo { pub public_key: String, /// The external internet endpoint to reach the server. - pub external_endpoint: SocketAddr, + pub external_endpoint: Endpoint, /// An internal endpoint in the WireGuard network that hosts the coordination API. pub internal_endpoint: SocketAddr, diff --git a/shared/src/prompts.rs b/shared/src/prompts.rs index e95f854..924f194 100644 --- a/shared/src/prompts.rs +++ b/shared/src/prompts.rs @@ -1,7 +1,7 @@ use crate::{ interface_config::{InterfaceConfig, InterfaceInfo, ServerInfo}, - AddCidrOpts, AddPeerOpts, Association, Cidr, CidrContents, CidrTree, Error, Peer, PeerContents, - PERSISTENT_KEEPALIVE_INTERVAL_SECS, + AddCidrOpts, AddPeerOpts, Association, Cidr, CidrContents, CidrTree, Endpoint, Error, Peer, + PeerContents, PERSISTENT_KEEPALIVE_INTERVAL_SECS, }; use colored::*; use dialoguer::{theme::ColorfulTheme, Confirm, Input, Select}; @@ -299,6 +299,7 @@ pub fn save_peer_invitation( server: ServerInfo { external_endpoint: server_peer .endpoint + .clone() .expect("The innernet server should have a WireGuard endpoint"), internal_endpoint: *server_api_addr, public_key: server_peer.public_key.clone(), @@ -362,7 +363,7 @@ pub fn set_listen_port( } } -pub fn ask_endpoint(external_ip: Option) -> Result { +pub fn ask_endpoint(external_ip: Option) -> Result { println!("getting external IP address."); let external_ip = if external_ip.is_some() { @@ -379,7 +380,7 @@ pub fn ask_endpoint(external_ip: Option) -> Result { let mut endpoint_builder = Input::with_theme(&*THEME); if let Some(ip) = external_ip { - endpoint_builder.default(SocketAddr::new(ip, 51820)); + endpoint_builder.default(SocketAddr::new(ip, 51820).into()); } else { println!("failed to get external IP."); } @@ -389,7 +390,7 @@ pub fn ask_endpoint(external_ip: Option) -> Result { .map_err(|e| Error::from(e)) } -pub fn override_endpoint(unset: bool) -> Result>, Error> { +pub fn override_endpoint(unset: bool) -> Result>, Error> { let endpoint = if !unset { Some(ask_endpoint(None)?) } else { diff --git a/shared/src/types.rs b/shared/src/types.rs index 3e6c952..4d2f146 100644 --- a/shared/src/types.rs +++ b/shared/src/types.rs @@ -2,13 +2,15 @@ use crate::prompts::hostname_validator; use ipnetwork::IpNetwork; use serde::{Deserialize, Serialize}; use std::{ - fmt::{Display, Formatter}, - net::{IpAddr, SocketAddr}, + fmt::{self, Display, Formatter}, + net::{IpAddr, SocketAddr, ToSocketAddrs}, ops::Deref, path::Path, str::FromStr, + vec, }; use structopt::StructOpt; +use url::Host; use wgctrl::{InterfaceName, InvalidInterfaceName, Key, PeerConfig, PeerConfigBuilder}; #[derive(Debug, Clone)] @@ -37,15 +39,105 @@ impl Deref for Interface { } } +#[derive(Clone, Debug, PartialEq)] +/// An external endpoint that supports both IP and domain name hosts. +pub struct Endpoint { + host: Host, + port: u16, +} + +impl From for Endpoint { + fn from(addr: SocketAddr) -> Self { + match addr { + SocketAddr::V4(v4addr) => Self { + host: Host::Ipv4(*v4addr.ip()), + port: v4addr.port(), + }, + SocketAddr::V6(v6addr) => Self { + host: Host::Ipv6(*v6addr.ip()), + port: v6addr.port(), + }, + } + } +} + +impl FromStr for Endpoint { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + match s.rsplitn(2, ':').collect::>().as_slice() { + [port, host] => { + let port = port.parse().map_err(|_| "couldn't parse port")?; + let host = Host::parse(host).map_err(|_| "couldn't parse host")?; + Ok(Endpoint { host, port }) + }, + _ => Err("couldn't parse in form of 'host:port'"), + } + } +} + +impl Serialize for Endpoint { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for Endpoint { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct EndpointVisitor; + impl<'de> serde::de::Visitor<'de> for EndpointVisitor { + type Value = Endpoint; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a valid host:port endpoint") + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + s.parse().map_err(serde::de::Error::custom) + } + } + deserializer.deserialize_str(EndpointVisitor) + } +} + +impl fmt::Display for Endpoint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.host.fmt(f)?; + f.write_str(":")?; + self.port.fmt(f) + } +} + +impl Endpoint { + pub fn resolve(&self) -> Result { + let mut addrs = self + .to_string() + .to_socket_addrs() + .map_err(|e| e.to_string())?; + addrs + .next() + .ok_or_else(|| "failed to resolve address".to_string()) + } +} + #[derive(Deserialize, Serialize, Debug)] #[serde(tag = "option", content = "content")] pub enum EndpointContents { - Set(SocketAddr), + Set(Endpoint), Unset, } -impl Into> for EndpointContents { - fn into(self) -> Option { +impl Into> for EndpointContents { + fn into(self) -> Option { match self { Self::Set(addr) => Some(addr), Self::Unset => None, @@ -53,8 +145,8 @@ impl Into> for EndpointContents { } } -impl From> for EndpointContents { - fn from(option: Option) -> Self { +impl From> for EndpointContents { + fn from(option: Option) -> Self { match option { Some(addr) => Self::Set(addr), None => Self::Unset, @@ -246,7 +338,7 @@ pub struct PeerContents { pub ip: IpAddr, pub cidr_id: i64, pub public_key: String, - pub endpoint: Option, + pub endpoint: Option, pub persistent_keepalive_interval: Option, pub is_admin: bool, pub is_disabled: bool, @@ -287,8 +379,11 @@ impl Peer { pub fn diff(&self, peer: &PeerConfig) -> Option { assert_eq!(self.public_key, peer.public_key.to_base64()); - let endpoint_diff = if peer.endpoint != self.endpoint { - self.endpoint + let endpoint_diff = if let Some(ref endpoint) = self.endpoint { + match endpoint.resolve() { + Ok(resolved) if Some(resolved) != peer.endpoint => Some(resolved), + _ => None, + } } else { None }; @@ -331,7 +426,9 @@ impl<'a> From<&'a Peer> for PeerConfigBuilder { builder }; - if let Some(endpoint) = peer.endpoint { + let resolved = peer.endpoint.as_ref().map(|e| e.resolve().ok()).flatten(); + + if let Some(endpoint) = resolved { builder.set_endpoint(endpoint) } else { builder