diff --git a/Cargo.lock b/Cargo.lock index e08c12a..87e735d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -393,9 +393,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfb77c123b4e2f72a2069aeae0b4b4949cc7e966df277813fc16347e7549737" +checksum = "60daa14be0e0786db0f03a9e57cb404c9d756eed2b6c62b9ea98ec5743ec75a9" dependencies = [ "bytes", "http", @@ -728,6 +728,10 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "publicip" +version = "0.1.0" + [[package]] name = "quick-error" version = "1.2.3" @@ -898,6 +902,7 @@ dependencies = [ "log", "parking_lot", "pretty_env_logger", + "publicip", "regex", "rusqlite", "serde", @@ -910,7 +915,6 @@ dependencies = [ "thiserror", "tokio", "toml", - "ureq", "url", "wgctrl", ] @@ -924,11 +928,11 @@ dependencies = [ "indoc", "ipnetwork", "lazy_static", + "publicip", "regex", "serde", "structopt", "toml", - "ureq", "url", "wgctrl", ] diff --git a/Cargo.toml b/Cargo.toml index 7eaf949..05e25f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["server", "client", "hostsfile", "shared"] +members = ["server", "client", "hostsfile", "shared", "publicip"] [profile.release] codegen-units = 1 diff --git a/publicip/Cargo.toml b/publicip/Cargo.toml new file mode 100644 index 0000000..ca7087a --- /dev/null +++ b/publicip/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "publicip" +version = "0.1.0" +authors = ["Jake McGinty "] +edition = "2018" + +[dependencies] diff --git a/publicip/src/lib.rs b/publicip/src/lib.rs new file mode 100644 index 0000000..2db4375 --- /dev/null +++ b/publicip/src/lib.rs @@ -0,0 +1,179 @@ +//! Get your public IP address(es) as fast as possible, with no dependencies. +//! +//! Currently uses Cloudflare's DNS as it's the simplest, but that could change +//! in the future. + +use std::{ + fs::File, + io::{Cursor, Error, ErrorKind, Read, Write}, + marker::PhantomData, + net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, + str::FromStr, + time::Duration, +}; + +macro_rules! ensure { + ($cond:expr, $msg:literal $(,)?) => { + if !$cond { + return Err(Error::new(ErrorKind::InvalidInput, $msg.to_string())); + } + }; +} + +const TYPE_TXT: u16 = 0x0010; // TXT-type requests (could also be A, AAAA, etc.) +const CLASS_CH: u16 = 0x0003; // Because we are in the chaos realm. + +static CLOUDFLARE_QNAME: &[&str] = &["whoami", "cloudflare"]; +const CLOUDFLARE_IPV4: Ipv4Addr = Ipv4Addr::new(1, 1, 1, 1); +const CLOUDFLARE_IPV6: Ipv6Addr = Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111); + +pub fn public_ip() -> Result<(Option, Option), Error> { + let ipv4 = Request::start(CLOUDFLARE_IPV4)?; + let ipv6 = Request::start(CLOUDFLARE_IPV6)?; + Ok((ipv4.read_response().ok(), ipv6.read_response().ok())) +} + +struct Request { + socket: UdpSocket, + id: [u8; 2], + buf: [u8; 1500], + _ip_type: PhantomData, +} + +impl + FromStr> Request { + fn start(resolver: T) -> Result { + let socket = UdpSocket::bind(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0))?; + socket.set_read_timeout(Some(Duration::from_millis(500)))?; + let endpoint = SocketAddr::new(resolver.into(), 53); + + let id = get_id()?; + let mut buf = [0u8; 1500]; + let mut cursor = Cursor::new(&mut buf[..]); + cursor.write_all(&id)?; + cursor.write_all(&0x0100u16.to_be_bytes())?; // Request type (query, in this case) + cursor.write_all(&0x0001u16.to_be_bytes())?; // Number of queries + cursor.write_all(&0x0000u16.to_be_bytes())?; // Number of responses + cursor.write_all(&0x0000u16.to_be_bytes())?; // Number of name server records + cursor.write_all(&0x0000u16.to_be_bytes())?; // Number of additional records + for atom in CLOUDFLARE_QNAME { + // Write the length of this atom followed by the string itself + cursor.write_all(&[atom.len() as u8])?; + cursor.write_all(atom.as_bytes())?; + } + // Finish the qname with a terminating byte (0-length atom). + cursor.write_all(&[0x00])?; + cursor.write_all(&TYPE_TXT.to_be_bytes())?; + cursor.write_all(&CLASS_CH.to_be_bytes())?; + + let len = cursor.position() as usize; + socket.connect(endpoint)?; + socket.send(&buf[..len])?; + + Ok(Self { + socket, + id, + buf, + _ip_type: PhantomData, + }) + } + + fn read_response(mut self) -> Result { + let len = self.socket.recv(&mut self.buf)?; + ensure!(self.buf[..2] == self.id, "question/answer IDs don't match"); + let response = &self.buf[..len]; + let mut buf = Cursor::new(response); + let _id = buf.read_u16()?; + + let flags = buf.read_u16()?; + ensure!(flags & 0x8000 != 0, "not a response"); + ensure!(flags & 0x000f == 0, "non-zero DNS error code"); + + let qd = buf.read_u16()?; + ensure!(qd <= 1, "unexpected number of questions"); + ensure!(buf.read_u16()? == 1, "unexpected number of answers"); + ensure!(buf.read_u16()? == 0, "unexpected NS value"); + ensure!(buf.read_u16()? == 0, "unexpected AR value"); + + // Skip past the query section, don't care. + if qd != 0 { + loop { + let len = buf.read_u8()?; + if len == 0 { + break; + } + buf.set_position(buf.position() + len as u64); + } + // Skip type and class information as well. + buf.set_position(buf.position() + 4); + } + + let qname_len = buf.read_u16()?; + // Ignore if it's a pointer, ignore if it's a normal QNAME... + if qname_len & 0xc000 != 0xc000 { + buf.set_position(buf.position() + qname_len as u64); + } + ensure!(buf.read_u16()? == TYPE_TXT, "answer is not TXT type"); + ensure!(buf.read_u16()? == CLASS_CH, "answer is not CH class"); + buf.set_position(buf.position() + 4); // Ignore TTL + + let data_len = buf.read_u16()? as usize; + let txt_len = buf.read_u8()? as usize; + ensure!(txt_len == data_len - 1, "unexpected txt and data lengths."); + + let start = buf.position() as usize; + let end = start + txt_len; + ensure!(response.len() >= end, "unexpected txt answer lengths"); + + let txt = std::str::from_utf8(&response[start..end]).ok(); + let answer = txt + .and_then(|txt| txt.parse::().ok()) + .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "TXT not IP address".to_string()))?; + + Ok(answer) + } +} + +/// DNS wants a random-ish ID to be generated per request. +fn get_id() -> Result<[u8; 2], Error> { + let mut id = [0u8; 2]; + File::open("/dev/urandom")?.read_exact(&mut id)?; + Ok(id) +} + +trait ReadExt { + fn read_u16(&mut self) -> Result; + fn read_u8(&mut self) -> Result; +} + +impl ReadExt for Cursor<&[u8]> { + fn read_u16(&mut self) -> Result { + let mut u16_buf = [0; 2]; + self.read_exact(&mut u16_buf)?; + Ok(u16::from_be_bytes(u16_buf)) + } + + fn read_u8(&mut self) -> Result { + let mut u8_buf = [0]; + self.read_exact(&mut u8_buf)?; + Ok(u8_buf[0]) + } +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use crate::*; + + #[test] + #[ignore] + fn it_works() -> Result<(), Error> { + let now = Instant::now(); + let (v4, v6) = public_ip()?; + println!("Done in {}ms", now.elapsed().as_millis()); + println!("v4: {:?}, v6: {:?}", v4, v6); + assert!(v4.is_some()); + assert!(v6.is_some()); + Ok(()) + } +} diff --git a/release.sh b/release.sh index c9cd604..987e021 100755 --- a/release.sh +++ b/release.sh @@ -17,7 +17,7 @@ done [ "$#" -eq 1 ] || die "usage: ./release.sh [patch|major|minor|rc]" git diff --quiet || die 'ERROR: git repo is dirty.' -cargo release "$1" --no-confirm --exclude "hostsfile" +cargo release "$1" --no-confirm --exclude "hostsfile,publicip" # re-stage the manpage commit and the cargo-release commit git reset --soft @~1 diff --git a/server/Cargo.toml b/server/Cargo.toml index 23914a2..16637a7 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -23,6 +23,7 @@ libc = "0.2" libsqlite3-sys = "0.22" log = "0.4" parking_lot = "0.11" +publicip = { path = "../publicip" } pretty_env_logger = "0.4" regex = { version = "1", default-features = false, features = ["std"] } rusqlite = "0.25" @@ -34,7 +35,6 @@ subtle = "2" thiserror = "1" tokio = { version = "1", features = ["macros", "rt-multi-thread", "time"] } toml = "0.5" -ureq = { version = "2", default-features = false } url = "2" wgctrl = { path = "../wgctrl-rs" } diff --git a/server/src/initialize.rs b/server/src/initialize.rs index 39c4422..ea12b1b 100644 --- a/server/src/initialize.rs +++ b/server/src/initialize.rs @@ -130,13 +130,8 @@ pub fn init_wizard(conf: &ServerConfig, opts: InitializeOpts) -> Result<(), Erro let endpoint: Endpoint = if let Some(endpoint) = opts.external_endpoint { endpoint } else { - let external_ip: Option = ureq::get("http://4.icanhazip.com") - .call() - .ok() - .map(|res| res.into_string().ok()) - .flatten() - .map(|body| body.trim().to_string()) - .and_then(|body| body.parse().ok()); + let (v4, v6) = publicip::public_ip()?; + let external_ip = v4.map(IpAddr::from).or(v6.map(IpAddr::from)); if opts.auto_external_endpoint { let ip = external_ip.ok_or("couldn't get external IP")?; diff --git a/shared/Cargo.toml b/shared/Cargo.toml index b280bef..4fa179f 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -12,10 +12,10 @@ dialoguer = "0.8" indoc = "1" ipnetwork = { git = "https://github.com/mcginty/ipnetwork" } # pending https://github.com/achanda/ipnetwork/pull/129 lazy_static = "1" +publicip = { path = "../publicip" } regex = "1" serde = { version = "1", features = ["derive"] } structopt = "0.3" toml = "0.5" -ureq = { version = "2", default-features = false } url = "2" wgctrl = { path = "../wgctrl-rs" } diff --git a/shared/src/prompts.rs b/shared/src/prompts.rs index e278db6..e319c8e 100644 --- a/shared/src/prompts.rs +++ b/shared/src/prompts.rs @@ -362,13 +362,8 @@ pub fn ask_endpoint(external_ip: Option) -> Result { let external_ip = if external_ip.is_some() { external_ip } else { - ureq::get("http://4.icanhazip.com") - .call() - .ok() - .map(|res| res.into_string().ok()) - .flatten() - .map(|body| body.trim().to_string()) - .and_then(|body| body.parse().ok()) + let (v4, v6) = publicip::public_ip()?; + v4.map(IpAddr::from).or(v6.map(IpAddr::from)) }; let mut endpoint_builder = Input::with_theme(&*THEME);