wgctrl-sys: Remove some unsafe in the kernel backend

Validates WireGuard interfaces against the linux specification for interface names.
Refactor userspace and other OSes to use InterfaceName
pull/37/head
BlackHoleFox 2021-04-08 20:28:37 -05:00 committed by GitHub
parent 67c69ecfa0
commit b1e1ff8f4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 307 additions and 142 deletions

View File

@ -6,6 +6,7 @@ use std::{
io::{Read, Seek, SeekFrom, Write}, io::{Read, Seek, SeekFrom, Write},
path::Path, path::Path,
}; };
use wgctrl::InterfaceName;
#[derive(Debug)] #[derive(Debug)]
pub struct DataStore { pub struct DataStore {
@ -38,19 +39,21 @@ impl DataStore {
Ok(Self { file, contents }) Ok(Self { file, contents })
} }
fn _open(interface: &str, create: bool) -> Result<Self, Error> { fn _open(interface: &InterfaceName, create: bool) -> Result<Self, Error> {
ensure_dirs_exist(&[*CLIENT_DATA_PATH])?; ensure_dirs_exist(&[*CLIENT_DATA_PATH])?;
Self::open_with_path( Self::open_with_path(
CLIENT_DATA_PATH.join(interface).with_extension("json"), CLIENT_DATA_PATH
.join(interface.to_string())
.with_extension("json"),
create, create,
) )
} }
pub fn open(interface: &str) -> Result<Self, Error> { pub fn open(interface: &InterfaceName) -> Result<Self, Error> {
Self::_open(interface, false) Self::_open(interface, false)
} }
pub fn open_or_create(interface: &str) -> Result<Self, Error> { pub fn open_or_create(interface: &InterfaceName) -> Result<Self, Error> {
Self::_open(interface, true) Self::_open(interface, true)
} }

View File

@ -14,7 +14,7 @@ use std::{
time::Duration, time::Duration,
}; };
use structopt::StructOpt; use structopt::StructOpt;
use wgctrl::{DeviceConfigBuilder, DeviceInfo, PeerConfigBuilder, PeerInfo}; use wgctrl::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfigBuilder, PeerInfo};
mod data_store; mod data_store;
mod util; mod util;
@ -155,7 +155,11 @@ impl std::error::Error for ClientError {
} }
} }
fn update_hosts_file(interface: &str, hosts_path: PathBuf, peers: &Vec<Peer>) -> Result<(), Error> { fn update_hosts_file(
interface: &InterfaceName,
hosts_path: PathBuf,
peers: &Vec<Peer>,
) -> Result<(), Error> {
println!( println!(
"{} updating {} with the latest peers.", "{} updating {} with the latest peers.",
"[*]".dimmed(), "[*]".dimmed(),
@ -189,6 +193,8 @@ fn install(invite: &Path, hosts_file: Option<PathBuf>) -> Result<(), Error> {
return Err("An interface with this name already exists in innernet.".into()); return Err("An interface with this name already exists in innernet.".into());
} }
let iface = iface.parse()?;
println!("{} bringing up the interface.", "[*]".dimmed()); println!("{} bringing up the interface.", "[*]".dimmed());
wg::up( wg::up(
&iface, &iface,
@ -267,7 +273,7 @@ fn install(invite: &Path, hosts_file: Option<PathBuf>) -> Result<(), Error> {
", ",
star = "[*]".dimmed(), star = "[*]".dimmed(),
interface = iface.yellow(), interface = iface.to_string().yellow(),
installed = "installed".green(), installed = "installed".green(),
systemctl_enable = "systemctl enable --now innernet@".yellow(), systemctl_enable = "systemctl enable --now innernet@".yellow(),
); );
@ -276,7 +282,7 @@ fn install(invite: &Path, hosts_file: Option<PathBuf>) -> Result<(), Error> {
} }
fn up( fn up(
interface: &str, interface: &InterfaceName,
loop_interval: Option<Duration>, loop_interval: Option<Duration>,
hosts_path: Option<PathBuf>, hosts_path: Option<PathBuf>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -292,7 +298,7 @@ fn up(
} }
fn fetch( fn fetch(
interface: &str, interface: &InterfaceName,
bring_up_interface: bool, bring_up_interface: bool,
hosts_path: Option<PathBuf>, hosts_path: Option<PathBuf>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -398,7 +404,7 @@ fn fetch(
println!( println!(
"\n{} updated interface {}\n", "\n{} updated interface {}\n",
"[*]".dimmed(), "[*]".dimmed(),
interface.yellow() interface.as_str_lossy().yellow()
); );
} else { } else {
println!("{}", " peers are already up to date.".green()); println!("{}", " peers are already up to date.".green());
@ -410,7 +416,7 @@ fn fetch(
Ok(()) Ok(())
} }
fn add_cidr(interface: &str) -> Result<(), Error> { fn add_cidr(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs"); println!("Fetching CIDRs");
let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?; let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?;
@ -435,7 +441,7 @@ fn add_cidr(interface: &str) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn add_peer(interface: &str) -> Result<(), Error> { fn add_peer(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs"); println!("Fetching CIDRs");
let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?; let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?;
@ -462,7 +468,7 @@ fn add_peer(interface: &str) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn enable_or_disable_peer(interface: &str, enable: bool) -> Result<(), Error> { fn enable_or_disable_peer(interface: &InterfaceName, enable: bool) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching peers."); println!("Fetching peers.");
let peers: Vec<Peer> = http_get(&server.internal_endpoint, "/admin/peers")?; let peers: Vec<Peer> = http_get(&server.internal_endpoint, "/admin/peers")?;
@ -482,7 +488,7 @@ fn enable_or_disable_peer(interface: &str, enable: bool) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn add_association(interface: &str) -> Result<(), Error> { fn add_association(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs"); println!("Fetching CIDRs");
@ -504,7 +510,7 @@ fn add_association(interface: &str) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn delete_association(interface: &str) -> Result<(), Error> { fn delete_association(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs"); println!("Fetching CIDRs");
@ -525,7 +531,7 @@ fn delete_association(interface: &str) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn list_associations(interface: &str) -> Result<(), Error> { fn list_associations(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?; let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs"); println!("Fetching CIDRs");
let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?; let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?;
@ -555,7 +561,7 @@ fn list_associations(interface: &str) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn set_listen_port(interface: &str, unset: bool) -> Result<(), Error> { fn set_listen_port(interface: &InterfaceName, unset: bool) -> Result<(), Error> {
let mut config = InterfaceConfig::from_interface(interface)?; let mut config = InterfaceConfig::from_interface(interface)?;
if let Some(listen_port) = prompts::set_listen_port(&config.interface, unset)? { if let Some(listen_port) = prompts::set_listen_port(&config.interface, unset)? {
@ -572,7 +578,7 @@ fn set_listen_port(interface: &str, unset: bool) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn override_endpoint(interface: &str, unset: bool) -> Result<(), Error> { fn override_endpoint(interface: &InterfaceName, unset: bool) -> Result<(), Error> {
let config = InterfaceConfig::from_interface(interface)?; let config = InterfaceConfig::from_interface(interface)?;
if !unset && config.interface.listen_port.is_none() { if !unset && config.interface.listen_port.is_none() {
println!( println!(
@ -597,10 +603,8 @@ fn override_endpoint(interface: &str, unset: bool) -> Result<(), Error> {
} }
fn show(short: bool, tree: bool, interface: Option<Interface>) -> Result<(), Error> { fn show(short: bool, tree: bool, interface: Option<Interface>) -> Result<(), Error> {
let interfaces = interface.map_or_else( let interfaces =
|| DeviceInfo::enumerate(), interface.map_or_else(|| DeviceInfo::enumerate(), |interface| Ok(vec![*interface]))?;
|interface| Ok(vec![interface.to_string()]),
)?;
let devices = interfaces.into_iter().filter_map(|name| { let devices = interfaces.into_iter().filter_map(|name| {
DataStore::open(&name) DataStore::open(&name)
@ -678,7 +682,7 @@ fn print_interface(device_info: &DeviceInfo, me: &Peer, short: bool) -> Result<(
.to_base64(); .to_base64();
if short { if short {
println!("{}", device_info.name.green().bold()); println!("{}", device_info.name.to_string().green().bold());
println!( println!(
" {} {}: {} ({}...)", " {} {}: {} ({}...)",
"(you)".bold(), "(you)".bold(),
@ -690,7 +694,7 @@ fn print_interface(device_info: &DeviceInfo, me: &Peer, short: bool) -> Result<(
println!( println!(
"{}: {} ({}...)", "{}: {} ({}...)",
"interface".green().bold(), "interface".green().bold(),
device_info.name.green(), device_info.name.to_string().green(),
public_key[..10].yellow() public_key[..10].yellow()
); );
if !short { if !short {

View File

@ -1,6 +1,6 @@
use crossbeam::channel::{self, select}; use crossbeam::channel::{self, select};
use dashmap::DashMap; use dashmap::DashMap;
use wgctrl::DeviceInfo; use wgctrl::{DeviceInfo, InterfaceName};
use std::{io, net::SocketAddr, sync::Arc, thread, time::Duration}; use std::{io, net::SocketAddr, sync::Arc, thread, time::Duration};
@ -18,7 +18,7 @@ impl std::ops::Deref for Endpoints {
} }
impl Endpoints { impl Endpoints {
pub fn new(iface: &str) -> Result<Self, io::Error> { pub fn new(iface: &InterfaceName) -> Result<Self, io::Error> {
let endpoints = Arc::new(DashMap::new()); let endpoints = Arc::new(DashMap::new());
let (stop_tx, stop_rx) = channel::bounded(1); let (stop_tx, stop_rx) = channel::bounded(1);

View File

@ -100,6 +100,9 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> {
(name, root_cidr) (name, root_cidr)
}); });
// This probably won't error because of the `hostname_validator` regex.
let name = name.parse()?;
let endpoint: SocketAddr = conf.endpoint.unwrap_or_else(|| { let endpoint: SocketAddr = conf.endpoint.unwrap_or_else(|| {
prompts::ask_endpoint() prompts::ask_endpoint()
.map_err(|_| println!("failed to get endpoint.")) .map_err(|_| println!("failed to get endpoint."))
@ -131,7 +134,7 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> {
config.write_to_path(&config_path)?; config.write_to_path(&config_path)?;
let db_init_data = DbInitData { let db_init_data = DbInitData {
root_cidr_name: name.clone(), root_cidr_name: name.to_string(),
root_cidr, root_cidr,
server_cidr, server_cidr,
our_ip, our_ip,
@ -176,7 +179,7 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> {
", ",
star = "[*]".dimmed(), star = "[*]".dimmed(),
interface = name.yellow(), interface = name.to_string().yellow(),
created = "created".green(), created = "created".green(),
wg_manage_server = "innernet-server".yellow(), wg_manage_server = "innernet-server".yellow(),
add_cidr = "add-cidr".yellow(), add_cidr = "add-cidr".yellow(),

View File

@ -19,7 +19,7 @@ use std::{
}; };
use structopt::StructOpt; use structopt::StructOpt;
use warp::Filter; use warp::Filter;
use wgctrl::{DeviceConfigBuilder, DeviceInfo, PeerConfigBuilder}; use wgctrl::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfigBuilder};
pub mod api; pub mod api;
pub mod db; pub mod db;
@ -67,7 +67,7 @@ pub type Db = Arc<Mutex<Connection>>;
pub struct Context { pub struct Context {
pub db: Db, pub db: Db,
pub endpoints: Arc<Endpoints>, pub endpoints: Arc<Endpoints>,
pub interface: String, pub interface: InterfaceName,
} }
pub struct Session { pub struct Session {
@ -140,10 +140,10 @@ impl ServerConfig {
.unwrap_or(*SERVER_DATABASE_DIR) .unwrap_or(*SERVER_DATABASE_DIR)
} }
fn database_path(&self, interface: &str) -> PathBuf { fn database_path(&self, interface: &InterfaceName) -> PathBuf {
PathBuf::new() PathBuf::new()
.join(self.database_dir()) .join(self.database_dir())
.join(interface) .join(interface.to_string())
.with_extension("db") .with_extension("db")
} }
@ -153,10 +153,10 @@ impl ServerConfig {
.unwrap_or(*SERVER_CONFIG_DIR) .unwrap_or(*SERVER_CONFIG_DIR)
} }
fn config_path(&self, interface: &str) -> PathBuf { fn config_path(&self, interface: &InterfaceName) -> PathBuf {
PathBuf::new() PathBuf::new()
.join(self.config_dir()) .join(self.config_dir())
.join(interface) .join(interface.to_string())
.with_extension("conf") .with_extension("conf")
} }
} }
@ -192,7 +192,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
fn open_database_connection( fn open_database_connection(
interface: &str, interface: &InterfaceName,
conf: &ServerConfig, conf: &ServerConfig,
) -> Result<rusqlite::Connection, Box<dyn std::error::Error>> { ) -> Result<rusqlite::Connection, Box<dyn std::error::Error>> {
let database_path = conf.database_path(&interface); let database_path = conf.database_path(&interface);
@ -207,8 +207,8 @@ fn open_database_connection(
Ok(Connection::open(&database_path)?) Ok(Connection::open(&database_path)?)
} }
fn add_peer(interface: &str, conf: &ServerConfig) -> Result<(), Error> { fn add_peer(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(&interface))?; let config = ConfigFile::from_file(conf.config_path(interface))?;
let conn = open_database_connection(interface, conf)?; let conn = open_database_connection(interface, conf)?;
let peers = DatabasePeer::list(&conn)? let peers = DatabasePeer::list(&conn)?
.into_iter() .into_iter()
@ -245,7 +245,7 @@ fn add_peer(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn add_cidr(interface: &str, conf: &ServerConfig) -> Result<(), Error> { fn add_cidr(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> {
let conn = open_database_connection(interface, conf)?; let conn = open_database_connection(interface, conf)?;
let cidrs = DatabaseCidr::list(&conn)?; let cidrs = DatabaseCidr::list(&conn)?;
if let Some(cidr_request) = shared::prompts::add_cidr(&cidrs)? { if let Some(cidr_request) = shared::prompts::add_cidr(&cidrs)? {
@ -268,9 +268,9 @@ fn add_cidr(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
Ok(()) Ok(())
} }
async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> { async fn serve(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(&interface))?; let config = ConfigFile::from_file(conf.config_path(interface))?;
let conn = open_database_connection(&interface, conf)?; let conn = open_database_connection(interface, conf)?;
// Foreign key constraints aren't on in SQLite by default. Enable. // Foreign key constraints aren't on in SQLite by default. Enable.
conn.pragma_update(None, "foreign_keys", &1)?; conn.pragma_update(None, "foreign_keys", &1)?;
@ -282,7 +282,7 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
log::info!("bringing up interface."); log::info!("bringing up interface.");
wg::up( wg::up(
&interface, interface,
&config.private_key, &config.private_key,
IpNetwork::new(config.address, config.network_cidr_prefix)?, IpNetwork::new(config.address, config.network_cidr_prefix)?,
Some(config.listen_port), Some(config.listen_port),
@ -300,7 +300,7 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
let db = Arc::new(Mutex::new(conn)); let db = Arc::new(Mutex::new(conn));
let context = Context { let context = Context {
db, db,
interface: interface.to_string(), interface: *interface,
endpoints, endpoints,
}; };
@ -334,11 +334,11 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
/// ///
/// See https://github.com/tonarino/innernet/issues/26 for more details. /// See https://github.com/tonarino/innernet/issues/26 for more details.
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
fn get_listener(addr: SocketAddr, interface: &str) -> Result<TcpListener, Error> { fn get_listener(addr: SocketAddr, interface: &InterfaceName) -> Result<TcpListener, Error> {
let listener = TcpListener::bind(&addr)?; let listener = TcpListener::bind(&addr)?;
listener.set_nonblocking(true)?; listener.set_nonblocking(true)?;
let sock = socket2::Socket::from(listener); let sock = socket2::Socket::from(listener);
sock.bind_device(Some(interface.as_bytes()))?; sock.bind_device(Some(interface.as_str_lossy().as_bytes()))?;
Ok(sock.into()) Ok(sock.into())
} }
@ -349,7 +349,7 @@ fn get_listener(addr: SocketAddr, interface: &str) -> Result<TcpListener, Error>
/// ///
/// See https://github.com/tonarino/innernet/issues/26 for more details. /// See https://github.com/tonarino/innernet/issues/26 for more details.
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
fn get_listener(addr: SocketAddr, _interface: &str) -> Result<TcpListener, Error> { fn get_listener(addr: SocketAddr, _interface: &InterfaceName) -> Result<TcpListener, Error> {
let listener = TcpListener::bind(&addr)?; let listener = TcpListener::bind(&addr)?;
listener.set_nonblocking(true)?; listener.set_nonblocking(true)?;
Ok(listener) Ok(listener)

View File

@ -12,7 +12,7 @@ use shared::{Cidr, CidrContents, PeerContents};
use std::{net::SocketAddr, path::PathBuf, sync::Arc}; use std::{net::SocketAddr, path::PathBuf, sync::Arc};
use tempfile::TempDir; use tempfile::TempDir;
use warp::test::RequestBuilder; use warp::test::RequestBuilder;
use wgctrl::KeyPair; use wgctrl::{InterfaceName, KeyPair};
pub const ROOT_CIDR: &str = "10.80.0.0/15"; pub const ROOT_CIDR: &str = "10.80.0.0/15";
pub const SERVER_CIDR: &str = "10.80.0.1/32"; pub const SERVER_CIDR: &str = "10.80.0.1/32";
@ -45,7 +45,7 @@ pub const USER2_PEER_ID: i64 = 6;
pub struct Server { pub struct Server {
pub db: Arc<Mutex<Connection>>, pub db: Arc<Mutex<Connection>>,
endpoints: Arc<Endpoints>, endpoints: Arc<Endpoints>,
interface: String, interface: InterfaceName,
conf: ServerConfig, conf: ServerConfig,
// The directory will be removed during destruction. // The directory will be removed during destruction.
_test_dir: TempDir, _test_dir: TempDir,
@ -69,6 +69,7 @@ impl Server {
}; };
init_wizard(&conf).map_err(|_| anyhow!("init_wizard failed"))?; init_wizard(&conf).map_err(|_| anyhow!("init_wizard failed"))?;
let interface = interface.parse().unwrap();
// Add developer CIDR and user CIDR and some peers for testing. // Add developer CIDR and user CIDR and some peers for testing.
let db = Connection::open(&conf.database_path(&interface))?; let db = Connection::open(&conf.database_path(&interface))?;
db.pragma_update(None, "foreign_keys", &1)?; db.pragma_update(None, "foreign_keys", &1)?;

View File

@ -9,6 +9,7 @@ use std::{
os::unix::fs::PermissionsExt, os::unix::fs::PermissionsExt,
path::{Path, PathBuf}, path::{Path, PathBuf},
}; };
use wgctrl::InterfaceName;
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
@ -92,7 +93,7 @@ impl InterfaceConfig {
} }
/// Overwrites the config file if it already exists. /// Overwrites the config file if it already exists.
pub fn write_to_interface(&self, interface: &str) -> Result<PathBuf, Error> { pub fn write_to_interface(&self, interface: &InterfaceName) -> Result<PathBuf, Error> {
let path = Self::build_config_file_path(interface)?; let path = Self::build_config_file_path(interface)?;
File::create(&path) File::create(&path)
.with_path(&path)? .with_path(&path)?
@ -104,13 +105,15 @@ impl InterfaceConfig {
Ok(toml::from_slice(&std::fs::read(&path).with_path(path)?)?) Ok(toml::from_slice(&std::fs::read(&path).with_path(path)?)?)
} }
pub fn from_interface(interface: &str) -> Result<Self, Error> { pub fn from_interface(interface: &InterfaceName) -> Result<Self, Error> {
Self::from_file(Self::build_config_file_path(interface)?) Self::from_file(Self::build_config_file_path(interface)?)
} }
fn build_config_file_path(interface: &str) -> Result<PathBuf, Error> { fn build_config_file_path(interface: &InterfaceName) -> Result<PathBuf, Error> {
ensure_dirs_exist(&[*CLIENT_CONFIG_PATH])?; ensure_dirs_exist(&[*CLIENT_CONFIG_PATH])?;
Ok(CLIENT_CONFIG_PATH.join(interface).with_extension("conf")) Ok(CLIENT_CONFIG_PATH
.join(interface.to_string())
.with_extension("conf"))
} }
} }

View File

@ -13,7 +13,7 @@ use std::{
str::FromStr, str::FromStr,
time::Duration, time::Duration,
}; };
use wgctrl::{Key, PeerConfig, PeerConfigBuilder}; use wgctrl::{InterfaceName, InvalidInterfaceName, Key, PeerConfig, PeerConfigBuilder};
pub mod interface_config; pub mod interface_config;
pub mod prompts; pub mod prompts;
@ -65,23 +65,24 @@ impl std::error::Error for WrappedIoError {}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Interface { pub struct Interface {
name: String, name: InterfaceName,
} }
impl FromStr for Interface { impl FromStr for Interface {
type Err = &'static str; type Err = String;
fn from_str(name: &str) -> Result<Self, Self::Err> { fn from_str(name: &str) -> Result<Self, Self::Err> {
let s = name.to_string(); let name = name.to_string();
hostname_validator(&s)?; hostname_validator(&name)?;
Ok(Self { let name = name
name: name.to_string(), .parse()
}) .map_err(|e: InvalidInterfaceName| e.to_string())?;
Ok(Self { name })
} }
} }
impl Deref for Interface { impl Deref for Interface {
type Target = str; type Target = InterfaceName;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.name &self.name

View File

@ -9,7 +9,7 @@ use ipnetwork::IpNetwork;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use wgctrl::KeyPair; use wgctrl::{InterfaceName, KeyPair};
lazy_static! { lazy_static! {
static ref THEME: ColorfulTheme = ColorfulTheme::default(); static ref THEME: ColorfulTheme = ColorfulTheme::default();
@ -239,7 +239,7 @@ pub fn enable_or_disable_peer(peers: &[Peer], enable: bool) -> Result<Option<Pee
/// Confirm and write a innernet invitation file after a peer has been created. /// Confirm and write a innernet invitation file after a peer has been created.
pub fn save_peer_invitation( pub fn save_peer_invitation(
network_name: &str, network_name: &InterfaceName,
peer: &Peer, peer: &Peer,
server_peer: &Peer, server_peer: &Peer,
root_cidr: &Cidr, root_cidr: &Cidr,

View File

@ -4,7 +4,7 @@ use std::{
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
process::{self, Command}, process::{self, Command},
}; };
use wgctrl::{DeviceConfigBuilder, PeerConfigBuilder}; use wgctrl::{DeviceConfigBuilder, InterfaceName, PeerConfigBuilder};
fn cmd(bin: &str, args: &[&str]) -> Result<process::Output, Error> { fn cmd(bin: &str, args: &[&str]) -> Result<process::Output, Error> {
let output = Command::new(bin).args(args).output()?; let output = Command::new(bin).args(args).output()?;
@ -22,8 +22,9 @@ fn cmd(bin: &str, args: &[&str]) -> Result<process::Output, Error> {
} }
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> { pub fn set_addr(interface: &InterfaceName, addr: IpNetwork) -> Result<(), Error> {
let real_interface = wgctrl::backends::userspace::resolve_tun(interface).with_str(interface)?; let real_interface =
wgctrl::backends::userspace::resolve_tun(interface).with_str(interface.to_string())?;
if addr.is_ipv4() { if addr.is_ipv4() {
cmd( cmd(
@ -47,20 +48,21 @@ pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> {
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> { pub fn set_addr(interface: &InterfaceName, addr: IpNetwork) -> Result<(), Error> {
let interface = interface.to_string();
cmd( cmd(
"ip", "ip",
&["address", "replace", &addr.to_string(), "dev", interface], &["address", "replace", &addr.to_string(), "dev", &interface],
)?; )?;
let _ = cmd( let _ = cmd(
"ip", "ip",
&["link", "set", "mtu", "1420", "up", "dev", interface], &["link", "set", "mtu", "1420", "up", "dev", &interface],
); );
Ok(()) Ok(())
} }
pub fn up( pub fn up(
interface: &str, interface: &InterfaceName,
private_key: &str, private_key: &str,
address: IpNetwork, address: IpNetwork,
listen_port: Option<u16>, listen_port: Option<u16>,
@ -85,7 +87,7 @@ pub fn up(
Ok(()) Ok(())
} }
pub fn set_listen_port(interface: &str, listen_port: Option<u16>) -> Result<(), Error> { pub fn set_listen_port(interface: &InterfaceName, listen_port: Option<u16>) -> Result<(), Error> {
let mut device = DeviceConfigBuilder::new(); let mut device = DeviceConfigBuilder::new();
if let Some(listen_port) = listen_port { if let Some(listen_port) = listen_port {
device = device.set_listen_port(listen_port); device = device.set_listen_port(listen_port);
@ -98,24 +100,24 @@ pub fn set_listen_port(interface: &str, listen_port: Option<u16>) -> Result<(),
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub fn down(interface: &str) -> Result<(), Error> { pub fn down(interface: &InterfaceName) -> Result<(), Error> {
Ok(wgctrl::delete_interface(interface).with_str(interface)?) Ok(wgctrl::delete_interface(&interface).with_str(interface.to_string())?)
} }
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
pub fn down(interface: &str) -> Result<(), Error> { pub fn down(interface: &InterfaceName) -> Result<(), Error> {
wgctrl::backends::userspace::delete_interface(interface) wgctrl::backends::userspace::delete_interface(interface)
.with_str(interface) .with_str(interface.to_string())
.map_err(Error::from) .map_err(Error::from)
} }
/// Add a route in the OS's routing table to get traffic flowing through this interface. /// Add a route in the OS's routing table to get traffic flowing through this interface.
/// Returns an error if the process doesn't exit successfully, otherwise returns /// Returns an error if the process doesn't exit successfully, otherwise returns
/// true if the route was changed, false if the route already exists. /// true if the route was changed, false if the route already exists.
pub fn add_route(interface: &str, cidr: IpNetwork) -> Result<bool, Error> { pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result<bool, Error> {
if cfg!(target_os = "macos") { if cfg!(target_os = "macos") {
let real_interface = let real_interface =
wgctrl::backends::userspace::resolve_tun(interface).with_str(interface)?; wgctrl::backends::userspace::resolve_tun(interface).with_str(interface.to_string())?;
let output = cmd( let output = cmd(
"route", "route",
&[ &[
@ -141,7 +143,13 @@ pub fn add_route(interface: &str, cidr: IpNetwork) -> Result<bool, Error> {
// TODO(mcginty): use the netlink interface on linux to modify routing table. // TODO(mcginty): use the netlink interface on linux to modify routing table.
let _ = cmd( let _ = cmd(
"ip", "ip",
&["route", "add", &cidr.to_string(), "dev", &interface], &[
"route",
"add",
&cidr.to_string(),
"dev",
&interface.to_string(),
],
); );
Ok(false) Ok(false)
} }

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
device::AllowedIp, DeviceConfigBuilder, DeviceInfo, InvalidKey, PeerConfig, PeerConfigBuilder, device::AllowedIp, DeviceConfigBuilder, DeviceInfo, InterfaceName, InvalidInterfaceName,
PeerInfo, PeerStats, InvalidKey, PeerConfig, PeerConfigBuilder, PeerInfo, PeerStats,
}; };
use wgctrl_sys::{timespec64, wg_device_flags as wgdf, wg_peer_flags as wgpf}; use wgctrl_sys::{timespec64, wg_device_flags as wgdf, wg_peer_flags as wgpf};
@ -71,8 +71,10 @@ impl<'a> From<&'a wgctrl_sys::wg_peer> for PeerInfo {
impl<'a> From<&'a wgctrl_sys::wg_device> for DeviceInfo { impl<'a> From<&'a wgctrl_sys::wg_device> for DeviceInfo {
fn from(raw: &wgctrl_sys::wg_device) -> DeviceInfo { fn from(raw: &wgctrl_sys::wg_device) -> DeviceInfo {
// SAFETY: The name string buffer came directly from wgctrl so its NUL terminated.
let name = unsafe { InterfaceName::from_wg(raw.name) };
DeviceInfo { DeviceInfo {
name: parse_device_name(raw.name), name,
public_key: if (raw.flags & wgdf::WGDEVICE_HAS_PUBLIC_KEY).0 > 0 { public_key: if (raw.flags & wgdf::WGDEVICE_HAS_PUBLIC_KEY).0 > 0 {
Some(Key::from_raw(raw.public_key)) Some(Key::from_raw(raw.public_key))
} else { } else {
@ -98,15 +100,6 @@ impl<'a> From<&'a wgctrl_sys::wg_device> for DeviceInfo {
} }
} }
fn parse_device_name(name: [c_char; 16]) -> String {
let name: &[u8; 16] = unsafe { &*((&name) as *const _ as *const [u8; 16]) };
let idx: usize = name
.iter()
.position(|x| *x == 0)
.expect("Interface name too long?");
unsafe { str::from_utf8_unchecked(&name[..idx]) }.to_owned()
}
fn parse_peers(dev: &wgctrl_sys::wg_device) -> Vec<PeerInfo> { fn parse_peers(dev: &wgctrl_sys::wg_device) -> Vec<PeerInfo> {
let mut result = Vec::new(); let mut result = Vec::new();
@ -297,15 +290,6 @@ fn encode_peers(
(first_peer, last_peer) (first_peer, last_peer)
} }
fn encode_name(name: &str) -> [c_char; 16] {
let slice = unsafe { &*(name.as_bytes() as *const _ as *const [c_char]) };
let mut result = [c_char::default(); 16];
result[..slice.len()].copy_from_slice(slice);
result
}
pub fn exists() -> bool { pub fn exists() -> bool {
// Try to load the wireguard module if it isn't already. // Try to load the wireguard module if it isn't already.
// This is only called once per lifetime of the process. // This is only called once per lifetime of the process.
@ -320,7 +304,7 @@ pub fn exists() -> bool {
Path::new("/sys/module/wireguard").is_dir() Path::new("/sys/module/wireguard").is_dir()
} }
pub fn enumerate() -> Result<Vec<String>, io::Error> { pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
let base = unsafe { wgctrl_sys::wg_list_device_names() }; let base = unsafe { wgctrl_sys::wg_list_device_names() };
if base.is_null() { if base.is_null() {
@ -340,7 +324,12 @@ pub fn enumerate() -> Result<Vec<String>, io::Error> {
} }
current = unsafe { current.add(len + 1) }; current = unsafe { current.add(len + 1) };
result.push(unsafe { str::from_utf8_unchecked(next_dev) }.to_owned());
let interface: InterfaceName = str::from_utf8(next_dev)
.map_err(|_| InvalidInterfaceName::InvalidChars)?
.parse()?;
result.push(interface);
} }
unsafe { libc::free(base as *mut libc::c_void) }; unsafe { libc::free(base as *mut libc::c_void) };
@ -348,18 +337,17 @@ pub fn enumerate() -> Result<Vec<String>, io::Error> {
Ok(result) Ok(result)
} }
pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> { pub fn apply(builder: DeviceConfigBuilder, iface: &InterfaceName) -> io::Result<()> {
let (first_peer, last_peer) = encode_peers(builder.peers); let (first_peer, last_peer) = encode_peers(builder.peers);
let iface_str = CString::new(iface)?; let result = unsafe { wgctrl_sys::wg_add_device(iface.as_ptr()) };
let result = unsafe { wgctrl_sys::wg_add_device(iface_str.as_ptr()) };
match result { match result {
0 | -17 => {}, 0 | -17 => {},
_ => return Err(io::Error::last_os_error()), _ => return Err(io::Error::last_os_error()),
}; };
let mut wg_device = Box::new(wgctrl_sys::wg_device { let mut wg_device = Box::new(wgctrl_sys::wg_device {
name: encode_name(iface), name: iface.into_inner(),
ifindex: 0, ifindex: 0,
public_key: wgctrl_sys::wg_key::default(), public_key: wgctrl_sys::wg_key::default(),
private_key: wgctrl_sys::wg_key::default(), private_key: wgctrl_sys::wg_key::default(),
@ -406,15 +394,13 @@ pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> {
} }
} }
pub fn get_by_name(name: &str) -> Result<DeviceInfo, io::Error> { pub fn get_by_name(name: &InterfaceName) -> Result<DeviceInfo, io::Error> {
let mut device: *mut wgctrl_sys::wg_device = ptr::null_mut(); let mut device: *mut wgctrl_sys::wg_device = ptr::null_mut();
let cs = CString::new(name)?;
let result = unsafe { let result = unsafe {
wgctrl_sys::wg_get_device( wgctrl_sys::wg_get_device(
(&mut device) as *mut _ as *mut *mut wgctrl_sys::wg_device, (&mut device) as *mut _ as *mut *mut wgctrl_sys::wg_device,
cs.as_ptr(), name.as_ptr(),
) )
}; };
@ -429,9 +415,8 @@ pub fn get_by_name(name: &str) -> Result<DeviceInfo, io::Error> {
result result
} }
pub fn delete_interface(iface: &str) -> io::Result<()> { pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
let iface_str = CString::new(iface)?; let result = unsafe { wgctrl_sys::wg_del_device(iface.as_ptr()) };
let result = unsafe { wgctrl_sys::wg_del_device(iface_str.as_ptr()) };
if result == 0 { if result == 0 {
Ok(()) Ok(())

View File

@ -1,11 +1,11 @@
use crate::{DeviceConfigBuilder, DeviceInfo, PeerConfig, PeerInfo, PeerStats}; use crate::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfig, PeerInfo, PeerStats};
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
use crate::Key; use crate::Key;
use std::{ use std::{
fs, io, fs,
io::{prelude::*, BufReader}, io::{self, prelude::*, BufReader},
os::unix::net::UnixStream, os::unix::net::UnixStream,
path::{Path, PathBuf}, path::{Path, PathBuf},
process::Command, process::Command,
@ -28,31 +28,31 @@ fn get_base_folder() -> io::Result<PathBuf> {
} }
} }
fn get_namefile(name: &str) -> io::Result<PathBuf> { fn get_namefile(name: &InterfaceName) -> io::Result<PathBuf> {
Ok(get_base_folder()?.join(&format!("{}.name", name))) Ok(get_base_folder()?.join(&format!("{}.name", name.as_str_lossy())))
} }
fn get_socketfile(name: &str) -> io::Result<PathBuf> { fn get_socketfile(name: &InterfaceName) -> io::Result<PathBuf> {
Ok(get_base_folder()?.join(&format!("{}.sock", resolve_tun(name)?))) Ok(get_base_folder()?.join(&format!("{}.sock", resolve_tun(name)?)))
} }
fn open_socket(name: &str) -> io::Result<UnixStream> { fn open_socket(name: &InterfaceName) -> io::Result<UnixStream> {
UnixStream::connect(get_socketfile(name)?) UnixStream::connect(get_socketfile(name)?)
} }
pub fn resolve_tun(name: &str) -> io::Result<String> { pub fn resolve_tun(name: &InterfaceName) -> io::Result<String> {
let namefile = get_namefile(name)?; let namefile = get_namefile(name)?;
Ok(fs::read_to_string(namefile)?.trim().to_string()) Ok(fs::read_to_string(namefile)?.trim().to_string())
} }
pub fn delete_interface(name: &str) -> io::Result<()> { pub fn delete_interface(name: &InterfaceName) -> io::Result<()> {
fs::remove_file(get_socketfile(name)?).ok(); fs::remove_file(get_socketfile(name)?).ok();
fs::remove_file(get_namefile(name)?).ok(); fs::remove_file(get_namefile(name)?).ok();
Ok(()) Ok(())
} }
pub fn enumerate() -> Result<Vec<String>, io::Error> { pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
use std::ffi::OsStr; use std::ffi::OsStr;
let mut interfaces = vec![]; let mut interfaces = vec![];
@ -61,7 +61,7 @@ pub fn enumerate() -> Result<Vec<String>, io::Error> {
if path.extension() == Some(OsStr::new("name")) { if path.extension() == Some(OsStr::new("name")) {
let stem = path.file_stem().map(|stem| stem.to_str()).flatten(); let stem = path.file_stem().map(|stem| stem.to_str()).flatten();
if let Some(name) = stem { if let Some(name) = stem {
interfaces.push(name.to_string()); interfaces.push(name.parse()?);
} }
} }
} }
@ -100,9 +100,10 @@ impl From<ConfigParser> for DeviceInfo {
} }
impl ConfigParser { impl ConfigParser {
fn new(name: &str) -> Self { /// Returns `None` if an invalid device name was provided.
fn new(name: &InterfaceName) -> Self {
let device_info = DeviceInfo { let device_info = DeviceInfo {
name: name.to_string(), name: *name,
public_key: None, public_key: None,
private_key: None, private_key: None,
fwmark: None, fwmark: None,
@ -228,13 +229,14 @@ impl ConfigParser {
} }
} }
pub fn get_by_name(name: &str) -> Result<DeviceInfo, io::Error> { pub fn get_by_name(name: &InterfaceName) -> Result<DeviceInfo, io::Error> {
let mut sock = open_socket(name)?; let mut sock = open_socket(name)?;
sock.write_all(b"get=1\n\n")?; sock.write_all(b"get=1\n\n")?;
let mut reader = BufReader::new(sock); let mut reader = BufReader::new(sock);
let mut buf = String::new(); let mut buf = String::new();
let mut parser = ConfigParser::new(name); let mut parser = ConfigParser::new(name);
loop { loop {
match reader.read_line(&mut buf)? { match reader.read_line(&mut buf)? {
0 | 1 if buf == "\n" => break, 0 | 1 if buf == "\n" => break,
@ -261,7 +263,7 @@ fn get_userspace_implementation() -> String {
.unwrap_or_else(|_| "wireguard-go".to_string()) .unwrap_or_else(|_| "wireguard-go".to_string())
} }
pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> { pub fn apply(builder: DeviceConfigBuilder, iface: &InterfaceName) -> io::Result<()> {
// If we can't open a configuration socket to an existing interface, try starting it. // If we can't open a configuration socket to an existing interface, try starting it.
let mut sock = match open_socket(iface) { let mut sock = match open_socket(iface) {
Err(_) => { Err(_) => {

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
backends, backends,
device::{AllowedIp, PeerConfig}, device::{AllowedIp, InterfaceName, PeerConfig},
key::{Key, KeyPair}, key::{Key, KeyPair},
}; };
@ -35,7 +35,7 @@ use std::{
/// peer.set_endpoint(server_addr) /// peer.set_endpoint(server_addr)
/// .replace_allowed_ips() /// .replace_allowed_ips()
/// .allow_all_ips() /// .allow_all_ips()
/// }).apply("wg-example"); /// }).apply(&"wg-example".parse().unwrap());
/// ///
/// println!("Send these keys to your peer: {:#?}", peer_keypair); /// println!("Send these keys to your peer: {:#?}", peer_keypair);
/// ///
@ -171,16 +171,16 @@ impl DeviceConfigBuilder {
/// ///
/// An interface with the provided name will be created if one does not exist already. /// An interface with the provided name will be created if one does not exist already.
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub fn apply(self, iface: &str) -> io::Result<()> { pub fn apply(self, iface: &InterfaceName) -> io::Result<()> {
if backends::kernel::exists() { if backends::kernel::exists() {
backends::kernel::apply(self, iface) backends::kernel::apply(self, &iface)
} else { } else {
backends::userspace::apply(self, iface) backends::userspace::apply(self, iface)
} }
} }
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
pub fn apply(self, iface: &str) -> io::Result<()> { pub fn apply(self, iface: &InterfaceName) -> io::Result<()> {
backends::userspace::apply(self, iface) backends::userspace::apply(self, iface)
} }
} }
@ -215,7 +215,7 @@ impl Default for DeviceConfigBuilder {
/// .add_allowed_ip("192.168.1.2".parse()?, 32); /// .add_allowed_ip("192.168.1.2".parse()?, 32);
/// ///
/// // update our existing configuration with the new peer /// // update our existing configuration with the new peer
/// DeviceConfigBuilder::new().add_peer(peer).apply("wg-example"); /// DeviceConfigBuilder::new().add_peer(peer).apply(&"wg-example".parse().unwrap());
/// ///
/// println!("Send these keys to your peer: {:#?}", peer_keypair); /// println!("Send these keys to your peer: {:#?}", peer_keypair);
/// ///
@ -353,6 +353,6 @@ impl PeerConfigBuilder {
/// Deletes an existing WireGuard interface by name. /// Deletes an existing WireGuard interface by name.
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub fn delete_interface(iface: &str) -> io::Result<()> { pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
backends::kernel::delete_interface(iface) backends::kernel::delete_interface(iface)
} }

View File

@ -1,7 +1,13 @@
use libc::c_char;
use crate::{backends, key::Key}; use crate::{backends, key::Key};
use std::{ use std::{
borrow::Cow,
ffi::CStr,
fmt,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
str::FromStr,
time::SystemTime, time::SystemTime,
}; };
@ -86,7 +92,7 @@ pub struct PeerInfo {
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct DeviceInfo { pub struct DeviceInfo {
/// The interface name of this device /// The interface name of this device
pub name: String, pub name: InterfaceName,
/// The public encryption key of this interface (if present) /// The public encryption key of this interface (if present)
pub public_key: Option<Key>, pub public_key: Option<Key>,
/// The private encryption key of this interface (if present) /// The private encryption key of this interface (if present)
@ -103,6 +109,132 @@ pub struct DeviceInfo {
pub(crate) __cant_construct_me: (), pub(crate) __cant_construct_me: (),
} }
type RawInterfaceName = [c_char; libc::IFNAMSIZ];
/// The name of a Wireguard interface device.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct InterfaceName(RawInterfaceName);
impl FromStr for InterfaceName {
type Err = InvalidInterfaceName;
/// Attempts to parse a Rust string as a valid Linux interface name.
///
/// Extra validation logic ported from [iproute2](https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/utils.c#n827)
fn from_str(name: &str) -> Result<Self, InvalidInterfaceName> {
let len = name.len();
// Ensure its short enough to include a trailing NUL
if len > (libc::IFNAMSIZ - 1) {
return Err(InvalidInterfaceName::TooLong(len));
}
if len == 0 || name.trim_start_matches('\0').is_empty() {
return Err(InvalidInterfaceName::Empty);
}
let mut buf = [c_char::default(); libc::IFNAMSIZ];
// Check for interior NULs and other invalid characters.
for (out, b) in buf.iter_mut().zip(name.as_bytes()[..(len - 1)].iter()) {
if *b == 0 {
return Err(InvalidInterfaceName::InteriorNul);
}
if *b == b'/' || b.is_ascii_whitespace() {
return Err(InvalidInterfaceName::InvalidChars);
}
*out = *b as i8;
}
Ok(Self(buf))
}
}
impl InterfaceName {
#[cfg(target_os = "linux")]
/// Creates a new [InterfaceName](Self).
///
/// ## Safety
///
/// The caller must ensure that `name` is a valid C string terminated by a NUL.
pub(crate) unsafe fn from_wg(name: RawInterfaceName) -> Self {
Self(name)
}
/// Returns a human-readable form of the device name.
///
/// Only use this when the interface name was constructed from a Rust string.
pub fn as_str_lossy(&self) -> Cow<'_, str> {
// SAFETY: These are C strings coming from wgctrl, so they are correctly NUL terminated.
unsafe { CStr::from_ptr(self.0.as_ptr()) }.to_string_lossy()
}
#[cfg(target_os = "linux")]
/// Returns a pointer to the inner byte buffer for FFI calls.
pub(crate) fn as_ptr(&self) -> *const c_char {
self.0.as_ptr()
}
#[cfg(target_os = "linux")]
/// Consumes this interface name, returning its raw byte buffer.
pub(crate) fn into_inner(self) -> RawInterfaceName {
self.0
}
}
impl fmt::Debug for InterfaceName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.as_str_lossy())
}
}
impl fmt::Display for InterfaceName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.as_str_lossy())
}
}
/// An interface name was bad.
#[derive(Debug, PartialEq)]
pub enum InvalidInterfaceName {
/// Provided name had an interior NUL byte.
InteriorNul,
/// Provided name was longer then the interface name length limit
/// of the system.
TooLong(usize),
// These checks are done in the kernel as well, but no reason to let bad names
// get that far: https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/utils.c?id=1f420318bda3cc62156e89e1b56d60cc744b48ad#n827.
/// Interface name was an empty string.
Empty,
/// Interface name contained a `/` or space character.
InvalidChars,
}
impl fmt::Display for InvalidInterfaceName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InteriorNul => f.write_str("interface name contained an interior NUL byte"),
Self::TooLong(size) => write!(
f,
"interface name was {} bytes long but the system's max is {}",
size,
libc::IFNAMSIZ
),
Self::Empty => f.write_str("an empty interface name was provided"),
Self::InvalidChars => f.write_str("interface name contained slash or space characters"),
}
}
}
impl From<InvalidInterfaceName> for std::io::Error {
fn from(e: InvalidInterfaceName) -> Self {
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
}
}
impl std::error::Error for InvalidInterfaceName {}
impl DeviceInfo { impl DeviceInfo {
/// Enumerates all WireGuard interfaces currently present in the system /// Enumerates all WireGuard interfaces currently present in the system
/// and returns their names. /// and returns their names.
@ -110,7 +242,7 @@ impl DeviceInfo {
/// You can use [`get_by_name`](DeviceInfo::get_by_name) to retrieve more /// You can use [`get_by_name`](DeviceInfo::get_by_name) to retrieve more
/// detailed information on each interface. /// detailed information on each interface.
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub fn enumerate() -> Result<Vec<String>, std::io::Error> { pub fn enumerate() -> Result<Vec<InterfaceName>, std::io::Error> {
if backends::kernel::exists() { if backends::kernel::exists() {
backends::kernel::enumerate() backends::kernel::enumerate()
} else { } else {
@ -119,12 +251,12 @@ impl DeviceInfo {
} }
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
pub fn enumerate() -> Result<Vec<String>, std::io::Error> { pub fn enumerate() -> Result<Vec<InterfaceName>, std::io::Error> {
crate::backends::userspace::enumerate() crate::backends::userspace::enumerate()
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub fn get_by_name(name: &str) -> Result<Self, std::io::Error> { pub fn get_by_name(name: &InterfaceName) -> Result<Self, std::io::Error> {
if backends::kernel::exists() { if backends::kernel::exists() {
backends::kernel::get_by_name(name) backends::kernel::get_by_name(name)
} else { } else {
@ -133,7 +265,7 @@ impl DeviceInfo {
} }
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
pub fn get_by_name(name: &str) -> Result<Self, std::io::Error> { pub fn get_by_name(name: &InterfaceName) -> Result<Self, std::io::Error> {
backends::userspace::get_by_name(name) backends::userspace::get_by_name(name)
} }
@ -150,7 +282,9 @@ impl DeviceInfo {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{DeviceConfigBuilder, KeyPair, PeerConfigBuilder}; use crate::{
DeviceConfigBuilder, InterfaceName, InvalidInterfaceName, KeyPair, PeerConfigBuilder,
};
const TEST_INTERFACE: &str = "wgctrl-test"; const TEST_INTERFACE: &str = "wgctrl-test";
use super::*; use super::*;
@ -166,9 +300,10 @@ mod tests {
for keypair in &keypairs { for keypair in &keypairs {
builder = builder.add_peer(PeerConfigBuilder::new(&keypair.public)) builder = builder.add_peer(PeerConfigBuilder::new(&keypair.public))
} }
builder.apply(TEST_INTERFACE).unwrap(); let interface = TEST_INTERFACE.parse().unwrap();
builder.apply(&interface).unwrap();
let device = DeviceInfo::get_by_name(TEST_INTERFACE).unwrap(); let device = DeviceInfo::get_by_name(&interface).unwrap();
for keypair in &keypairs { for keypair in &keypairs {
assert!(device assert!(device
@ -179,4 +314,24 @@ mod tests {
device.delete().unwrap(); device.delete().unwrap();
} }
#[test]
fn test_interface_names() {
assert!("wg-01".parse::<InterfaceName>().is_ok());
assert!("longer-nul\0".parse::<InterfaceName>().is_ok());
let invalid_names = &[
("", InvalidInterfaceName::Empty), // Empty Rust string
("\0", InvalidInterfaceName::Empty), // Empty C string
("ifname\0nul", InvalidInterfaceName::InteriorNul), // Contains interior NUL
("if name", InvalidInterfaceName::InvalidChars), // Contains a space
("ifna/me", InvalidInterfaceName::InvalidChars), // Contains a slash
("if na/me", InvalidInterfaceName::InvalidChars), // Contains a space and slash
("interfacelongname", InvalidInterfaceName::TooLong(17)), // Too long
];
for (name, expected) in invalid_names {
assert!(name.parse::<InterfaceName>().as_ref() == Err(expected))
}
}
} }