diff --git a/client/src/data_store.rs b/client/src/data_store.rs index 15550f9..b405a91 100644 --- a/client/src/data_store.rs +++ b/client/src/data_store.rs @@ -6,6 +6,7 @@ use std::{ io::{Read, Seek, SeekFrom, Write}, path::Path, }; +use wgctrl::InterfaceName; #[derive(Debug)] pub struct DataStore { @@ -38,19 +39,21 @@ impl DataStore { Ok(Self { file, contents }) } - fn _open(interface: &str, create: bool) -> Result { + fn _open(interface: &InterfaceName, create: bool) -> Result { ensure_dirs_exist(&[*CLIENT_DATA_PATH])?; Self::open_with_path( - CLIENT_DATA_PATH.join(interface).with_extension("json"), + CLIENT_DATA_PATH + .join(interface.to_string()) + .with_extension("json"), create, ) } - pub fn open(interface: &str) -> Result { + pub fn open(interface: &InterfaceName) -> Result { Self::_open(interface, false) } - pub fn open_or_create(interface: &str) -> Result { + pub fn open_or_create(interface: &InterfaceName) -> Result { Self::_open(interface, true) } diff --git a/client/src/main.rs b/client/src/main.rs index 64994cd..8def1b3 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -14,7 +14,7 @@ use std::{ time::Duration, }; use structopt::StructOpt; -use wgctrl::{DeviceConfigBuilder, DeviceInfo, PeerConfigBuilder, PeerInfo}; +use wgctrl::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfigBuilder, PeerInfo}; mod data_store; mod util; @@ -155,7 +155,11 @@ impl std::error::Error for ClientError { } } -fn update_hosts_file(interface: &str, hosts_path: PathBuf, peers: &Vec) -> Result<(), Error> { +fn update_hosts_file( + interface: &InterfaceName, + hosts_path: PathBuf, + peers: &Vec, +) -> Result<(), Error> { println!( "{} updating {} with the latest peers.", "[*]".dimmed(), @@ -189,6 +193,8 @@ fn install(invite: &Path, hosts_file: Option) -> Result<(), Error> { return Err("An interface with this name already exists in innernet.".into()); } + let iface = iface.parse()?; + println!("{} bringing up the interface.", "[*]".dimmed()); wg::up( &iface, @@ -267,7 +273,7 @@ fn install(invite: &Path, hosts_file: Option) -> Result<(), Error> { ", star = "[*]".dimmed(), - interface = iface.yellow(), + interface = iface.to_string().yellow(), installed = "installed".green(), systemctl_enable = "systemctl enable --now innernet@".yellow(), ); @@ -276,7 +282,7 @@ fn install(invite: &Path, hosts_file: Option) -> Result<(), Error> { } fn up( - interface: &str, + interface: &InterfaceName, loop_interval: Option, hosts_path: Option, ) -> Result<(), Error> { @@ -292,7 +298,7 @@ fn up( } fn fetch( - interface: &str, + interface: &InterfaceName, bring_up_interface: bool, hosts_path: Option, ) -> Result<(), Error> { @@ -398,7 +404,7 @@ fn fetch( println!( "\n{} updated interface {}\n", "[*]".dimmed(), - interface.yellow() + interface.as_str_lossy().yellow() ); } else { println!("{}", " peers are already up to date.".green()); @@ -410,7 +416,7 @@ fn fetch( Ok(()) } -fn add_cidr(interface: &str) -> Result<(), Error> { +fn add_cidr(interface: &InterfaceName) -> Result<(), Error> { let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; println!("Fetching CIDRs"); let cidrs: Vec = http_get(&server.internal_endpoint, "/admin/cidrs")?; @@ -435,7 +441,7 @@ fn add_cidr(interface: &str) -> Result<(), Error> { Ok(()) } -fn add_peer(interface: &str) -> Result<(), Error> { +fn add_peer(interface: &InterfaceName) -> Result<(), Error> { let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; println!("Fetching CIDRs"); let cidrs: Vec = http_get(&server.internal_endpoint, "/admin/cidrs")?; @@ -462,7 +468,7 @@ fn add_peer(interface: &str) -> Result<(), Error> { Ok(()) } -fn enable_or_disable_peer(interface: &str, enable: bool) -> Result<(), Error> { +fn enable_or_disable_peer(interface: &InterfaceName, enable: bool) -> Result<(), Error> { let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; println!("Fetching peers."); let peers: Vec = http_get(&server.internal_endpoint, "/admin/peers")?; @@ -482,7 +488,7 @@ fn enable_or_disable_peer(interface: &str, enable: bool) -> Result<(), Error> { Ok(()) } -fn add_association(interface: &str) -> Result<(), Error> { +fn add_association(interface: &InterfaceName) -> Result<(), Error> { let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; println!("Fetching CIDRs"); @@ -504,7 +510,7 @@ fn add_association(interface: &str) -> Result<(), Error> { Ok(()) } -fn delete_association(interface: &str) -> Result<(), Error> { +fn delete_association(interface: &InterfaceName) -> Result<(), Error> { let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; println!("Fetching CIDRs"); @@ -525,7 +531,7 @@ fn delete_association(interface: &str) -> Result<(), Error> { Ok(()) } -fn list_associations(interface: &str) -> Result<(), Error> { +fn list_associations(interface: &InterfaceName) -> Result<(), Error> { let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; println!("Fetching CIDRs"); let cidrs: Vec = http_get(&server.internal_endpoint, "/admin/cidrs")?; @@ -555,7 +561,7 @@ fn list_associations(interface: &str) -> Result<(), Error> { Ok(()) } -fn set_listen_port(interface: &str, unset: bool) -> Result<(), Error> { +fn set_listen_port(interface: &InterfaceName, unset: bool) -> Result<(), Error> { let mut config = InterfaceConfig::from_interface(interface)?; if let Some(listen_port) = prompts::set_listen_port(&config.interface, unset)? { @@ -572,7 +578,7 @@ fn set_listen_port(interface: &str, unset: bool) -> Result<(), Error> { Ok(()) } -fn override_endpoint(interface: &str, unset: bool) -> Result<(), Error> { +fn override_endpoint(interface: &InterfaceName, unset: bool) -> Result<(), Error> { let config = InterfaceConfig::from_interface(interface)?; if !unset && config.interface.listen_port.is_none() { println!( @@ -597,10 +603,8 @@ fn override_endpoint(interface: &str, unset: bool) -> Result<(), Error> { } fn show(short: bool, tree: bool, interface: Option) -> Result<(), Error> { - let interfaces = interface.map_or_else( - || DeviceInfo::enumerate(), - |interface| Ok(vec![interface.to_string()]), - )?; + let interfaces = + interface.map_or_else(|| DeviceInfo::enumerate(), |interface| Ok(vec![*interface]))?; let devices = interfaces.into_iter().filter_map(|name| { DataStore::open(&name) @@ -678,7 +682,7 @@ fn print_interface(device_info: &DeviceInfo, me: &Peer, short: bool) -> Result<( .to_base64(); if short { - println!("{}", device_info.name.green().bold()); + println!("{}", device_info.name.to_string().green().bold()); println!( " {} {}: {} ({}...)", "(you)".bold(), @@ -690,7 +694,7 @@ fn print_interface(device_info: &DeviceInfo, me: &Peer, short: bool) -> Result<( println!( "{}: {} ({}...)", "interface".green().bold(), - device_info.name.green(), + device_info.name.to_string().green(), public_key[..10].yellow() ); if !short { diff --git a/server/src/endpoints.rs b/server/src/endpoints.rs index c19c178..e3d8083 100644 --- a/server/src/endpoints.rs +++ b/server/src/endpoints.rs @@ -1,6 +1,6 @@ use crossbeam::channel::{self, select}; use dashmap::DashMap; -use wgctrl::DeviceInfo; +use wgctrl::{DeviceInfo, InterfaceName}; use std::{io, net::SocketAddr, sync::Arc, thread, time::Duration}; @@ -18,7 +18,7 @@ impl std::ops::Deref for Endpoints { } impl Endpoints { - pub fn new(iface: &str) -> Result { + pub fn new(iface: &InterfaceName) -> Result { let endpoints = Arc::new(DashMap::new()); let (stop_tx, stop_rx) = channel::bounded(1); diff --git a/server/src/initialize.rs b/server/src/initialize.rs index afaf98f..883d306 100644 --- a/server/src/initialize.rs +++ b/server/src/initialize.rs @@ -100,6 +100,9 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> { (name, root_cidr) }); + // This probably won't error because of the `hostname_validator` regex. + let name = name.parse()?; + let endpoint: SocketAddr = conf.endpoint.unwrap_or_else(|| { prompts::ask_endpoint() .map_err(|_| println!("failed to get endpoint.")) @@ -131,7 +134,7 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> { config.write_to_path(&config_path)?; let db_init_data = DbInitData { - root_cidr_name: name.clone(), + root_cidr_name: name.to_string(), root_cidr, server_cidr, our_ip, @@ -176,7 +179,7 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> { ", star = "[*]".dimmed(), - interface = name.yellow(), + interface = name.to_string().yellow(), created = "created".green(), wg_manage_server = "innernet-server".yellow(), add_cidr = "add-cidr".yellow(), diff --git a/server/src/main.rs b/server/src/main.rs index af432b3..55b6639 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -19,7 +19,7 @@ use std::{ }; use structopt::StructOpt; use warp::Filter; -use wgctrl::{DeviceConfigBuilder, DeviceInfo, PeerConfigBuilder}; +use wgctrl::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfigBuilder}; pub mod api; pub mod db; @@ -67,7 +67,7 @@ pub type Db = Arc>; pub struct Context { pub db: Db, pub endpoints: Arc, - pub interface: String, + pub interface: InterfaceName, } pub struct Session { @@ -140,10 +140,10 @@ impl ServerConfig { .unwrap_or(*SERVER_DATABASE_DIR) } - fn database_path(&self, interface: &str) -> PathBuf { + fn database_path(&self, interface: &InterfaceName) -> PathBuf { PathBuf::new() .join(self.database_dir()) - .join(interface) + .join(interface.to_string()) .with_extension("db") } @@ -153,10 +153,10 @@ impl ServerConfig { .unwrap_or(*SERVER_CONFIG_DIR) } - fn config_path(&self, interface: &str) -> PathBuf { + fn config_path(&self, interface: &InterfaceName) -> PathBuf { PathBuf::new() .join(self.config_dir()) - .join(interface) + .join(interface.to_string()) .with_extension("conf") } } @@ -192,7 +192,7 @@ async fn main() -> Result<(), Box> { } fn open_database_connection( - interface: &str, + interface: &InterfaceName, conf: &ServerConfig, ) -> Result> { let database_path = conf.database_path(&interface); @@ -207,8 +207,8 @@ fn open_database_connection( Ok(Connection::open(&database_path)?) } -fn add_peer(interface: &str, conf: &ServerConfig) -> Result<(), Error> { - let config = ConfigFile::from_file(conf.config_path(&interface))?; +fn add_peer(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> { + let config = ConfigFile::from_file(conf.config_path(interface))?; let conn = open_database_connection(interface, conf)?; let peers = DatabasePeer::list(&conn)? .into_iter() @@ -245,7 +245,7 @@ fn add_peer(interface: &str, conf: &ServerConfig) -> Result<(), Error> { Ok(()) } -fn add_cidr(interface: &str, conf: &ServerConfig) -> Result<(), Error> { +fn add_cidr(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> { let conn = open_database_connection(interface, conf)?; let cidrs = DatabaseCidr::list(&conn)?; if let Some(cidr_request) = shared::prompts::add_cidr(&cidrs)? { @@ -268,9 +268,9 @@ fn add_cidr(interface: &str, conf: &ServerConfig) -> Result<(), Error> { Ok(()) } -async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> { - let config = ConfigFile::from_file(conf.config_path(&interface))?; - let conn = open_database_connection(&interface, conf)?; +async fn serve(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> { + let config = ConfigFile::from_file(conf.config_path(interface))?; + let conn = open_database_connection(interface, conf)?; // Foreign key constraints aren't on in SQLite by default. Enable. conn.pragma_update(None, "foreign_keys", &1)?; @@ -282,7 +282,7 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> { log::info!("bringing up interface."); wg::up( - &interface, + interface, &config.private_key, IpNetwork::new(config.address, config.network_cidr_prefix)?, Some(config.listen_port), @@ -300,7 +300,7 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> { let db = Arc::new(Mutex::new(conn)); let context = Context { db, - interface: interface.to_string(), + interface: *interface, endpoints, }; @@ -334,11 +334,11 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> { /// /// See https://github.com/tonarino/innernet/issues/26 for more details. #[cfg(target_os = "linux")] -fn get_listener(addr: SocketAddr, interface: &str) -> Result { +fn get_listener(addr: SocketAddr, interface: &InterfaceName) -> Result { let listener = TcpListener::bind(&addr)?; listener.set_nonblocking(true)?; let sock = socket2::Socket::from(listener); - sock.bind_device(Some(interface.as_bytes()))?; + sock.bind_device(Some(interface.as_str_lossy().as_bytes()))?; Ok(sock.into()) } @@ -349,7 +349,7 @@ fn get_listener(addr: SocketAddr, interface: &str) -> Result /// /// See https://github.com/tonarino/innernet/issues/26 for more details. #[cfg(not(target_os = "linux"))] -fn get_listener(addr: SocketAddr, _interface: &str) -> Result { +fn get_listener(addr: SocketAddr, _interface: &InterfaceName) -> Result { let listener = TcpListener::bind(&addr)?; listener.set_nonblocking(true)?; Ok(listener) diff --git a/server/src/test.rs b/server/src/test.rs index 4a8c9b1..d389320 100644 --- a/server/src/test.rs +++ b/server/src/test.rs @@ -12,7 +12,7 @@ use shared::{Cidr, CidrContents, PeerContents}; use std::{net::SocketAddr, path::PathBuf, sync::Arc}; use tempfile::TempDir; use warp::test::RequestBuilder; -use wgctrl::KeyPair; +use wgctrl::{InterfaceName, KeyPair}; pub const ROOT_CIDR: &str = "10.80.0.0/15"; pub const SERVER_CIDR: &str = "10.80.0.1/32"; @@ -45,7 +45,7 @@ pub const USER2_PEER_ID: i64 = 6; pub struct Server { pub db: Arc>, endpoints: Arc, - interface: String, + interface: InterfaceName, conf: ServerConfig, // The directory will be removed during destruction. _test_dir: TempDir, @@ -69,6 +69,7 @@ impl Server { }; init_wizard(&conf).map_err(|_| anyhow!("init_wizard failed"))?; + let interface = interface.parse().unwrap(); // Add developer CIDR and user CIDR and some peers for testing. let db = Connection::open(&conf.database_path(&interface))?; db.pragma_update(None, "foreign_keys", &1)?; diff --git a/shared/src/interface_config.rs b/shared/src/interface_config.rs index cecaee3..053691c 100644 --- a/shared/src/interface_config.rs +++ b/shared/src/interface_config.rs @@ -9,6 +9,7 @@ use std::{ os::unix::fs::PermissionsExt, path::{Path, PathBuf}, }; +use wgctrl::InterfaceName; #[derive(Deserialize, Serialize, Debug)] #[serde(rename_all = "kebab-case")] @@ -92,7 +93,7 @@ impl InterfaceConfig { } /// Overwrites the config file if it already exists. - pub fn write_to_interface(&self, interface: &str) -> Result { + pub fn write_to_interface(&self, interface: &InterfaceName) -> Result { let path = Self::build_config_file_path(interface)?; File::create(&path) .with_path(&path)? @@ -104,13 +105,15 @@ impl InterfaceConfig { Ok(toml::from_slice(&std::fs::read(&path).with_path(path)?)?) } - pub fn from_interface(interface: &str) -> Result { + pub fn from_interface(interface: &InterfaceName) -> Result { Self::from_file(Self::build_config_file_path(interface)?) } - fn build_config_file_path(interface: &str) -> Result { + fn build_config_file_path(interface: &InterfaceName) -> Result { ensure_dirs_exist(&[*CLIENT_CONFIG_PATH])?; - Ok(CLIENT_CONFIG_PATH.join(interface).with_extension("conf")) + Ok(CLIENT_CONFIG_PATH + .join(interface.to_string()) + .with_extension("conf")) } } diff --git a/shared/src/lib.rs b/shared/src/lib.rs index e413b10..0b0e723 100644 --- a/shared/src/lib.rs +++ b/shared/src/lib.rs @@ -13,7 +13,7 @@ use std::{ str::FromStr, time::Duration, }; -use wgctrl::{Key, PeerConfig, PeerConfigBuilder}; +use wgctrl::{InterfaceName, InvalidInterfaceName, Key, PeerConfig, PeerConfigBuilder}; pub mod interface_config; pub mod prompts; @@ -65,23 +65,24 @@ impl std::error::Error for WrappedIoError {} #[derive(Debug, Clone)] pub struct Interface { - name: String, + name: InterfaceName, } impl FromStr for Interface { - type Err = &'static str; + type Err = String; fn from_str(name: &str) -> Result { - let s = name.to_string(); - hostname_validator(&s)?; - Ok(Self { - name: name.to_string(), - }) + let name = name.to_string(); + hostname_validator(&name)?; + let name = name + .parse() + .map_err(|e: InvalidInterfaceName| e.to_string())?; + Ok(Self { name }) } } impl Deref for Interface { - type Target = str; + type Target = InterfaceName; fn deref(&self) -> &Self::Target { &self.name diff --git a/shared/src/prompts.rs b/shared/src/prompts.rs index f1c316b..5506793 100644 --- a/shared/src/prompts.rs +++ b/shared/src/prompts.rs @@ -9,7 +9,7 @@ use ipnetwork::IpNetwork; use lazy_static::lazy_static; use regex::Regex; use std::net::{IpAddr, SocketAddr}; -use wgctrl::KeyPair; +use wgctrl::{InterfaceName, KeyPair}; lazy_static! { static ref THEME: ColorfulTheme = ColorfulTheme::default(); @@ -239,7 +239,7 @@ pub fn enable_or_disable_peer(peers: &[Peer], enable: bool) -> Result Result { let output = Command::new(bin).args(args).output()?; @@ -22,8 +22,9 @@ fn cmd(bin: &str, args: &[&str]) -> Result { } #[cfg(target_os = "macos")] -pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> { - let real_interface = wgctrl::backends::userspace::resolve_tun(interface).with_str(interface)?; +pub fn set_addr(interface: &InterfaceName, addr: IpNetwork) -> Result<(), Error> { + let real_interface = + wgctrl::backends::userspace::resolve_tun(interface).with_str(interface.to_string())?; if addr.is_ipv4() { cmd( @@ -47,20 +48,21 @@ pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> { } #[cfg(target_os = "linux")] -pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> { +pub fn set_addr(interface: &InterfaceName, addr: IpNetwork) -> Result<(), Error> { + let interface = interface.to_string(); cmd( "ip", - &["address", "replace", &addr.to_string(), "dev", interface], + &["address", "replace", &addr.to_string(), "dev", &interface], )?; let _ = cmd( "ip", - &["link", "set", "mtu", "1420", "up", "dev", interface], + &["link", "set", "mtu", "1420", "up", "dev", &interface], ); Ok(()) } pub fn up( - interface: &str, + interface: &InterfaceName, private_key: &str, address: IpNetwork, listen_port: Option, @@ -85,7 +87,7 @@ pub fn up( Ok(()) } -pub fn set_listen_port(interface: &str, listen_port: Option) -> Result<(), Error> { +pub fn set_listen_port(interface: &InterfaceName, listen_port: Option) -> Result<(), Error> { let mut device = DeviceConfigBuilder::new(); if let Some(listen_port) = listen_port { device = device.set_listen_port(listen_port); @@ -98,24 +100,24 @@ pub fn set_listen_port(interface: &str, listen_port: Option) -> Result<(), } #[cfg(target_os = "linux")] -pub fn down(interface: &str) -> Result<(), Error> { - Ok(wgctrl::delete_interface(interface).with_str(interface)?) +pub fn down(interface: &InterfaceName) -> Result<(), Error> { + Ok(wgctrl::delete_interface(&interface).with_str(interface.to_string())?) } #[cfg(not(target_os = "linux"))] -pub fn down(interface: &str) -> Result<(), Error> { +pub fn down(interface: &InterfaceName) -> Result<(), Error> { wgctrl::backends::userspace::delete_interface(interface) - .with_str(interface) + .with_str(interface.to_string()) .map_err(Error::from) } /// Add a route in the OS's routing table to get traffic flowing through this interface. /// Returns an error if the process doesn't exit successfully, otherwise returns /// true if the route was changed, false if the route already exists. -pub fn add_route(interface: &str, cidr: IpNetwork) -> Result { +pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result { if cfg!(target_os = "macos") { let real_interface = - wgctrl::backends::userspace::resolve_tun(interface).with_str(interface)?; + wgctrl::backends::userspace::resolve_tun(interface).with_str(interface.to_string())?; let output = cmd( "route", &[ @@ -141,7 +143,13 @@ pub fn add_route(interface: &str, cidr: IpNetwork) -> Result { // TODO(mcginty): use the netlink interface on linux to modify routing table. let _ = cmd( "ip", - &["route", "add", &cidr.to_string(), "dev", &interface], + &[ + "route", + "add", + &cidr.to_string(), + "dev", + &interface.to_string(), + ], ); Ok(false) } diff --git a/wgctrl-rs/src/backends/kernel.rs b/wgctrl-rs/src/backends/kernel.rs index 38a1d35..c87acbb 100644 --- a/wgctrl-rs/src/backends/kernel.rs +++ b/wgctrl-rs/src/backends/kernel.rs @@ -1,6 +1,6 @@ use crate::{ - device::AllowedIp, DeviceConfigBuilder, DeviceInfo, InvalidKey, PeerConfig, PeerConfigBuilder, - PeerInfo, PeerStats, + device::AllowedIp, DeviceConfigBuilder, DeviceInfo, InterfaceName, InvalidInterfaceName, + InvalidKey, PeerConfig, PeerConfigBuilder, PeerInfo, PeerStats, }; use wgctrl_sys::{timespec64, wg_device_flags as wgdf, wg_peer_flags as wgpf}; @@ -71,8 +71,10 @@ impl<'a> From<&'a wgctrl_sys::wg_peer> for PeerInfo { impl<'a> From<&'a wgctrl_sys::wg_device> for DeviceInfo { fn from(raw: &wgctrl_sys::wg_device) -> DeviceInfo { + // SAFETY: The name string buffer came directly from wgctrl so its NUL terminated. + let name = unsafe { InterfaceName::from_wg(raw.name) }; DeviceInfo { - name: parse_device_name(raw.name), + name, public_key: if (raw.flags & wgdf::WGDEVICE_HAS_PUBLIC_KEY).0 > 0 { Some(Key::from_raw(raw.public_key)) } else { @@ -98,15 +100,6 @@ impl<'a> From<&'a wgctrl_sys::wg_device> for DeviceInfo { } } -fn parse_device_name(name: [c_char; 16]) -> String { - let name: &[u8; 16] = unsafe { &*((&name) as *const _ as *const [u8; 16]) }; - let idx: usize = name - .iter() - .position(|x| *x == 0) - .expect("Interface name too long?"); - unsafe { str::from_utf8_unchecked(&name[..idx]) }.to_owned() -} - fn parse_peers(dev: &wgctrl_sys::wg_device) -> Vec { let mut result = Vec::new(); @@ -297,15 +290,6 @@ fn encode_peers( (first_peer, last_peer) } -fn encode_name(name: &str) -> [c_char; 16] { - let slice = unsafe { &*(name.as_bytes() as *const _ as *const [c_char]) }; - - let mut result = [c_char::default(); 16]; - result[..slice.len()].copy_from_slice(slice); - - result -} - pub fn exists() -> bool { // Try to load the wireguard module if it isn't already. // This is only called once per lifetime of the process. @@ -320,7 +304,7 @@ pub fn exists() -> bool { Path::new("/sys/module/wireguard").is_dir() } -pub fn enumerate() -> Result, io::Error> { +pub fn enumerate() -> Result, io::Error> { let base = unsafe { wgctrl_sys::wg_list_device_names() }; if base.is_null() { @@ -340,7 +324,12 @@ pub fn enumerate() -> Result, io::Error> { } current = unsafe { current.add(len + 1) }; - result.push(unsafe { str::from_utf8_unchecked(next_dev) }.to_owned()); + + let interface: InterfaceName = str::from_utf8(next_dev) + .map_err(|_| InvalidInterfaceName::InvalidChars)? + .parse()?; + + result.push(interface); } unsafe { libc::free(base as *mut libc::c_void) }; @@ -348,18 +337,17 @@ pub fn enumerate() -> Result, io::Error> { Ok(result) } -pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> { +pub fn apply(builder: DeviceConfigBuilder, iface: &InterfaceName) -> io::Result<()> { let (first_peer, last_peer) = encode_peers(builder.peers); - let iface_str = CString::new(iface)?; - let result = unsafe { wgctrl_sys::wg_add_device(iface_str.as_ptr()) }; + let result = unsafe { wgctrl_sys::wg_add_device(iface.as_ptr()) }; match result { 0 | -17 => {}, _ => return Err(io::Error::last_os_error()), }; let mut wg_device = Box::new(wgctrl_sys::wg_device { - name: encode_name(iface), + name: iface.into_inner(), ifindex: 0, public_key: wgctrl_sys::wg_key::default(), private_key: wgctrl_sys::wg_key::default(), @@ -406,15 +394,13 @@ pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> { } } -pub fn get_by_name(name: &str) -> Result { +pub fn get_by_name(name: &InterfaceName) -> Result { let mut device: *mut wgctrl_sys::wg_device = ptr::null_mut(); - let cs = CString::new(name)?; - let result = unsafe { wgctrl_sys::wg_get_device( (&mut device) as *mut _ as *mut *mut wgctrl_sys::wg_device, - cs.as_ptr(), + name.as_ptr(), ) }; @@ -429,9 +415,8 @@ pub fn get_by_name(name: &str) -> Result { result } -pub fn delete_interface(iface: &str) -> io::Result<()> { - let iface_str = CString::new(iface)?; - let result = unsafe { wgctrl_sys::wg_del_device(iface_str.as_ptr()) }; +pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> { + let result = unsafe { wgctrl_sys::wg_del_device(iface.as_ptr()) }; if result == 0 { Ok(()) diff --git a/wgctrl-rs/src/backends/userspace.rs b/wgctrl-rs/src/backends/userspace.rs index 7042a29..107696b 100644 --- a/wgctrl-rs/src/backends/userspace.rs +++ b/wgctrl-rs/src/backends/userspace.rs @@ -1,11 +1,11 @@ -use crate::{DeviceConfigBuilder, DeviceInfo, PeerConfig, PeerInfo, PeerStats}; +use crate::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfig, PeerInfo, PeerStats}; #[cfg(target_os = "linux")] use crate::Key; use std::{ - fs, io, - io::{prelude::*, BufReader}, + fs, + io::{self, prelude::*, BufReader}, os::unix::net::UnixStream, path::{Path, PathBuf}, process::Command, @@ -28,31 +28,31 @@ fn get_base_folder() -> io::Result { } } -fn get_namefile(name: &str) -> io::Result { - Ok(get_base_folder()?.join(&format!("{}.name", name))) +fn get_namefile(name: &InterfaceName) -> io::Result { + Ok(get_base_folder()?.join(&format!("{}.name", name.as_str_lossy()))) } -fn get_socketfile(name: &str) -> io::Result { +fn get_socketfile(name: &InterfaceName) -> io::Result { Ok(get_base_folder()?.join(&format!("{}.sock", resolve_tun(name)?))) } -fn open_socket(name: &str) -> io::Result { +fn open_socket(name: &InterfaceName) -> io::Result { UnixStream::connect(get_socketfile(name)?) } -pub fn resolve_tun(name: &str) -> io::Result { +pub fn resolve_tun(name: &InterfaceName) -> io::Result { let namefile = get_namefile(name)?; Ok(fs::read_to_string(namefile)?.trim().to_string()) } -pub fn delete_interface(name: &str) -> io::Result<()> { +pub fn delete_interface(name: &InterfaceName) -> io::Result<()> { fs::remove_file(get_socketfile(name)?).ok(); fs::remove_file(get_namefile(name)?).ok(); Ok(()) } -pub fn enumerate() -> Result, io::Error> { +pub fn enumerate() -> Result, io::Error> { use std::ffi::OsStr; let mut interfaces = vec![]; @@ -61,7 +61,7 @@ pub fn enumerate() -> Result, io::Error> { if path.extension() == Some(OsStr::new("name")) { let stem = path.file_stem().map(|stem| stem.to_str()).flatten(); if let Some(name) = stem { - interfaces.push(name.to_string()); + interfaces.push(name.parse()?); } } } @@ -100,9 +100,10 @@ impl From for DeviceInfo { } impl ConfigParser { - fn new(name: &str) -> Self { + /// Returns `None` if an invalid device name was provided. + fn new(name: &InterfaceName) -> Self { let device_info = DeviceInfo { - name: name.to_string(), + name: *name, public_key: None, private_key: None, fwmark: None, @@ -228,13 +229,14 @@ impl ConfigParser { } } -pub fn get_by_name(name: &str) -> Result { +pub fn get_by_name(name: &InterfaceName) -> Result { let mut sock = open_socket(name)?; sock.write_all(b"get=1\n\n")?; let mut reader = BufReader::new(sock); let mut buf = String::new(); let mut parser = ConfigParser::new(name); + loop { match reader.read_line(&mut buf)? { 0 | 1 if buf == "\n" => break, @@ -261,7 +263,7 @@ fn get_userspace_implementation() -> String { .unwrap_or_else(|_| "wireguard-go".to_string()) } -pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> { +pub fn apply(builder: DeviceConfigBuilder, iface: &InterfaceName) -> io::Result<()> { // If we can't open a configuration socket to an existing interface, try starting it. let mut sock = match open_socket(iface) { Err(_) => { diff --git a/wgctrl-rs/src/config.rs b/wgctrl-rs/src/config.rs index cb37b79..c3e231d 100644 --- a/wgctrl-rs/src/config.rs +++ b/wgctrl-rs/src/config.rs @@ -1,6 +1,6 @@ use crate::{ backends, - device::{AllowedIp, PeerConfig}, + device::{AllowedIp, InterfaceName, PeerConfig}, key::{Key, KeyPair}, }; @@ -35,7 +35,7 @@ use std::{ /// peer.set_endpoint(server_addr) /// .replace_allowed_ips() /// .allow_all_ips() -/// }).apply("wg-example"); +/// }).apply(&"wg-example".parse().unwrap()); /// /// println!("Send these keys to your peer: {:#?}", peer_keypair); /// @@ -171,16 +171,16 @@ impl DeviceConfigBuilder { /// /// An interface with the provided name will be created if one does not exist already. #[cfg(target_os = "linux")] - pub fn apply(self, iface: &str) -> io::Result<()> { + pub fn apply(self, iface: &InterfaceName) -> io::Result<()> { if backends::kernel::exists() { - backends::kernel::apply(self, iface) + backends::kernel::apply(self, &iface) } else { backends::userspace::apply(self, iface) } } #[cfg(not(target_os = "linux"))] - pub fn apply(self, iface: &str) -> io::Result<()> { + pub fn apply(self, iface: &InterfaceName) -> io::Result<()> { backends::userspace::apply(self, iface) } } @@ -215,7 +215,7 @@ impl Default for DeviceConfigBuilder { /// .add_allowed_ip("192.168.1.2".parse()?, 32); /// /// // update our existing configuration with the new peer -/// DeviceConfigBuilder::new().add_peer(peer).apply("wg-example"); +/// DeviceConfigBuilder::new().add_peer(peer).apply(&"wg-example".parse().unwrap()); /// /// println!("Send these keys to your peer: {:#?}", peer_keypair); /// @@ -353,6 +353,6 @@ impl PeerConfigBuilder { /// Deletes an existing WireGuard interface by name. #[cfg(target_os = "linux")] -pub fn delete_interface(iface: &str) -> io::Result<()> { +pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> { backends::kernel::delete_interface(iface) } diff --git a/wgctrl-rs/src/device.rs b/wgctrl-rs/src/device.rs index 1dd1479..81cb83b 100644 --- a/wgctrl-rs/src/device.rs +++ b/wgctrl-rs/src/device.rs @@ -1,7 +1,13 @@ +use libc::c_char; + use crate::{backends, key::Key}; use std::{ + borrow::Cow, + ffi::CStr, + fmt, net::{IpAddr, SocketAddr}, + str::FromStr, time::SystemTime, }; @@ -86,7 +92,7 @@ pub struct PeerInfo { #[derive(Debug, PartialEq, Eq, Clone)] pub struct DeviceInfo { /// The interface name of this device - pub name: String, + pub name: InterfaceName, /// The public encryption key of this interface (if present) pub public_key: Option, /// The private encryption key of this interface (if present) @@ -103,6 +109,132 @@ pub struct DeviceInfo { pub(crate) __cant_construct_me: (), } +type RawInterfaceName = [c_char; libc::IFNAMSIZ]; + +/// The name of a Wireguard interface device. +#[derive(PartialEq, Eq, Clone, Copy)] +pub struct InterfaceName(RawInterfaceName); + +impl FromStr for InterfaceName { + type Err = InvalidInterfaceName; + + /// Attempts to parse a Rust string as a valid Linux interface name. + /// + /// Extra validation logic ported from [iproute2](https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/utils.c#n827) + fn from_str(name: &str) -> Result { + let len = name.len(); + // Ensure its short enough to include a trailing NUL + if len > (libc::IFNAMSIZ - 1) { + return Err(InvalidInterfaceName::TooLong(len)); + } + + if len == 0 || name.trim_start_matches('\0').is_empty() { + return Err(InvalidInterfaceName::Empty); + } + + let mut buf = [c_char::default(); libc::IFNAMSIZ]; + // Check for interior NULs and other invalid characters. + for (out, b) in buf.iter_mut().zip(name.as_bytes()[..(len - 1)].iter()) { + if *b == 0 { + return Err(InvalidInterfaceName::InteriorNul); + } + + if *b == b'/' || b.is_ascii_whitespace() { + return Err(InvalidInterfaceName::InvalidChars); + } + + *out = *b as i8; + } + + Ok(Self(buf)) + } +} + +impl InterfaceName { + #[cfg(target_os = "linux")] + /// Creates a new [InterfaceName](Self). + /// + /// ## Safety + /// + /// The caller must ensure that `name` is a valid C string terminated by a NUL. + pub(crate) unsafe fn from_wg(name: RawInterfaceName) -> Self { + Self(name) + } + + /// Returns a human-readable form of the device name. + /// + /// Only use this when the interface name was constructed from a Rust string. + pub fn as_str_lossy(&self) -> Cow<'_, str> { + // SAFETY: These are C strings coming from wgctrl, so they are correctly NUL terminated. + unsafe { CStr::from_ptr(self.0.as_ptr()) }.to_string_lossy() + } + + #[cfg(target_os = "linux")] + /// Returns a pointer to the inner byte buffer for FFI calls. + pub(crate) fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() + } + + #[cfg(target_os = "linux")] + /// Consumes this interface name, returning its raw byte buffer. + pub(crate) fn into_inner(self) -> RawInterfaceName { + self.0 + } +} + +impl fmt::Debug for InterfaceName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.as_str_lossy()) + } +} + +impl fmt::Display for InterfaceName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.as_str_lossy()) + } +} + +/// An interface name was bad. +#[derive(Debug, PartialEq)] +pub enum InvalidInterfaceName { + /// Provided name had an interior NUL byte. + InteriorNul, + /// Provided name was longer then the interface name length limit + /// of the system. + TooLong(usize), + + // These checks are done in the kernel as well, but no reason to let bad names + // get that far: https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/utils.c?id=1f420318bda3cc62156e89e1b56d60cc744b48ad#n827. + /// Interface name was an empty string. + Empty, + /// Interface name contained a `/` or space character. + InvalidChars, +} + +impl fmt::Display for InvalidInterfaceName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InteriorNul => f.write_str("interface name contained an interior NUL byte"), + Self::TooLong(size) => write!( + f, + "interface name was {} bytes long but the system's max is {}", + size, + libc::IFNAMSIZ + ), + Self::Empty => f.write_str("an empty interface name was provided"), + Self::InvalidChars => f.write_str("interface name contained slash or space characters"), + } + } +} + +impl From for std::io::Error { + fn from(e: InvalidInterfaceName) -> Self { + std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()) + } +} + +impl std::error::Error for InvalidInterfaceName {} + impl DeviceInfo { /// Enumerates all WireGuard interfaces currently present in the system /// and returns their names. @@ -110,7 +242,7 @@ impl DeviceInfo { /// You can use [`get_by_name`](DeviceInfo::get_by_name) to retrieve more /// detailed information on each interface. #[cfg(target_os = "linux")] - pub fn enumerate() -> Result, std::io::Error> { + pub fn enumerate() -> Result, std::io::Error> { if backends::kernel::exists() { backends::kernel::enumerate() } else { @@ -119,12 +251,12 @@ impl DeviceInfo { } #[cfg(not(target_os = "linux"))] - pub fn enumerate() -> Result, std::io::Error> { + pub fn enumerate() -> Result, std::io::Error> { crate::backends::userspace::enumerate() } #[cfg(target_os = "linux")] - pub fn get_by_name(name: &str) -> Result { + pub fn get_by_name(name: &InterfaceName) -> Result { if backends::kernel::exists() { backends::kernel::get_by_name(name) } else { @@ -133,7 +265,7 @@ impl DeviceInfo { } #[cfg(not(target_os = "linux"))] - pub fn get_by_name(name: &str) -> Result { + pub fn get_by_name(name: &InterfaceName) -> Result { backends::userspace::get_by_name(name) } @@ -150,7 +282,9 @@ impl DeviceInfo { #[cfg(test)] mod tests { - use crate::{DeviceConfigBuilder, KeyPair, PeerConfigBuilder}; + use crate::{ + DeviceConfigBuilder, InterfaceName, InvalidInterfaceName, KeyPair, PeerConfigBuilder, + }; const TEST_INTERFACE: &str = "wgctrl-test"; use super::*; @@ -166,9 +300,10 @@ mod tests { for keypair in &keypairs { builder = builder.add_peer(PeerConfigBuilder::new(&keypair.public)) } - builder.apply(TEST_INTERFACE).unwrap(); + let interface = TEST_INTERFACE.parse().unwrap(); + builder.apply(&interface).unwrap(); - let device = DeviceInfo::get_by_name(TEST_INTERFACE).unwrap(); + let device = DeviceInfo::get_by_name(&interface).unwrap(); for keypair in &keypairs { assert!(device @@ -179,4 +314,24 @@ mod tests { device.delete().unwrap(); } + + #[test] + fn test_interface_names() { + assert!("wg-01".parse::().is_ok()); + assert!("longer-nul\0".parse::().is_ok()); + + let invalid_names = &[ + ("", InvalidInterfaceName::Empty), // Empty Rust string + ("\0", InvalidInterfaceName::Empty), // Empty C string + ("ifname\0nul", InvalidInterfaceName::InteriorNul), // Contains interior NUL + ("if name", InvalidInterfaceName::InvalidChars), // Contains a space + ("ifna/me", InvalidInterfaceName::InvalidChars), // Contains a slash + ("if na/me", InvalidInterfaceName::InvalidChars), // Contains a space and slash + ("interfacelongname", InvalidInterfaceName::TooLong(17)), // Too long + ]; + + for (name, expected) in invalid_names { + assert!(name.parse::().as_ref() == Err(expected)) + } + } }