wgctrl-sys: Remove some unsafe in the kernel backend

Validates WireGuard interfaces against the linux specification for interface names.
Refactor userspace and other OSes to use InterfaceName
pull/37/head
BlackHoleFox 2021-04-08 20:28:37 -05:00 committed by GitHub
parent 67c69ecfa0
commit b1e1ff8f4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 307 additions and 142 deletions

View File

@ -6,6 +6,7 @@ use std::{
io::{Read, Seek, SeekFrom, Write},
path::Path,
};
use wgctrl::InterfaceName;
#[derive(Debug)]
pub struct DataStore {
@ -38,19 +39,21 @@ impl DataStore {
Ok(Self { file, contents })
}
fn _open(interface: &str, create: bool) -> Result<Self, Error> {
fn _open(interface: &InterfaceName, create: bool) -> Result<Self, Error> {
ensure_dirs_exist(&[*CLIENT_DATA_PATH])?;
Self::open_with_path(
CLIENT_DATA_PATH.join(interface).with_extension("json"),
CLIENT_DATA_PATH
.join(interface.to_string())
.with_extension("json"),
create,
)
}
pub fn open(interface: &str) -> Result<Self, Error> {
pub fn open(interface: &InterfaceName) -> Result<Self, Error> {
Self::_open(interface, false)
}
pub fn open_or_create(interface: &str) -> Result<Self, Error> {
pub fn open_or_create(interface: &InterfaceName) -> Result<Self, Error> {
Self::_open(interface, true)
}

View File

@ -14,7 +14,7 @@ use std::{
time::Duration,
};
use structopt::StructOpt;
use wgctrl::{DeviceConfigBuilder, DeviceInfo, PeerConfigBuilder, PeerInfo};
use wgctrl::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfigBuilder, PeerInfo};
mod data_store;
mod util;
@ -155,7 +155,11 @@ impl std::error::Error for ClientError {
}
}
fn update_hosts_file(interface: &str, hosts_path: PathBuf, peers: &Vec<Peer>) -> Result<(), Error> {
fn update_hosts_file(
interface: &InterfaceName,
hosts_path: PathBuf,
peers: &Vec<Peer>,
) -> Result<(), Error> {
println!(
"{} updating {} with the latest peers.",
"[*]".dimmed(),
@ -189,6 +193,8 @@ fn install(invite: &Path, hosts_file: Option<PathBuf>) -> Result<(), Error> {
return Err("An interface with this name already exists in innernet.".into());
}
let iface = iface.parse()?;
println!("{} bringing up the interface.", "[*]".dimmed());
wg::up(
&iface,
@ -267,7 +273,7 @@ fn install(invite: &Path, hosts_file: Option<PathBuf>) -> Result<(), Error> {
",
star = "[*]".dimmed(),
interface = iface.yellow(),
interface = iface.to_string().yellow(),
installed = "installed".green(),
systemctl_enable = "systemctl enable --now innernet@".yellow(),
);
@ -276,7 +282,7 @@ fn install(invite: &Path, hosts_file: Option<PathBuf>) -> Result<(), Error> {
}
fn up(
interface: &str,
interface: &InterfaceName,
loop_interval: Option<Duration>,
hosts_path: Option<PathBuf>,
) -> Result<(), Error> {
@ -292,7 +298,7 @@ fn up(
}
fn fetch(
interface: &str,
interface: &InterfaceName,
bring_up_interface: bool,
hosts_path: Option<PathBuf>,
) -> Result<(), Error> {
@ -398,7 +404,7 @@ fn fetch(
println!(
"\n{} updated interface {}\n",
"[*]".dimmed(),
interface.yellow()
interface.as_str_lossy().yellow()
);
} else {
println!("{}", " peers are already up to date.".green());
@ -410,7 +416,7 @@ fn fetch(
Ok(())
}
fn add_cidr(interface: &str) -> Result<(), Error> {
fn add_cidr(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs");
let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?;
@ -435,7 +441,7 @@ fn add_cidr(interface: &str) -> Result<(), Error> {
Ok(())
}
fn add_peer(interface: &str) -> Result<(), Error> {
fn add_peer(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs");
let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?;
@ -462,7 +468,7 @@ fn add_peer(interface: &str) -> Result<(), Error> {
Ok(())
}
fn enable_or_disable_peer(interface: &str, enable: bool) -> Result<(), Error> {
fn enable_or_disable_peer(interface: &InterfaceName, enable: bool) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching peers.");
let peers: Vec<Peer> = http_get(&server.internal_endpoint, "/admin/peers")?;
@ -482,7 +488,7 @@ fn enable_or_disable_peer(interface: &str, enable: bool) -> Result<(), Error> {
Ok(())
}
fn add_association(interface: &str) -> Result<(), Error> {
fn add_association(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs");
@ -504,7 +510,7 @@ fn add_association(interface: &str) -> Result<(), Error> {
Ok(())
}
fn delete_association(interface: &str) -> Result<(), Error> {
fn delete_association(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs");
@ -525,7 +531,7 @@ fn delete_association(interface: &str) -> Result<(), Error> {
Ok(())
}
fn list_associations(interface: &str) -> Result<(), Error> {
fn list_associations(interface: &InterfaceName) -> Result<(), Error> {
let InterfaceConfig { server, .. } = InterfaceConfig::from_interface(interface)?;
println!("Fetching CIDRs");
let cidrs: Vec<Cidr> = http_get(&server.internal_endpoint, "/admin/cidrs")?;
@ -555,7 +561,7 @@ fn list_associations(interface: &str) -> Result<(), Error> {
Ok(())
}
fn set_listen_port(interface: &str, unset: bool) -> Result<(), Error> {
fn set_listen_port(interface: &InterfaceName, unset: bool) -> Result<(), Error> {
let mut config = InterfaceConfig::from_interface(interface)?;
if let Some(listen_port) = prompts::set_listen_port(&config.interface, unset)? {
@ -572,7 +578,7 @@ fn set_listen_port(interface: &str, unset: bool) -> Result<(), Error> {
Ok(())
}
fn override_endpoint(interface: &str, unset: bool) -> Result<(), Error> {
fn override_endpoint(interface: &InterfaceName, unset: bool) -> Result<(), Error> {
let config = InterfaceConfig::from_interface(interface)?;
if !unset && config.interface.listen_port.is_none() {
println!(
@ -597,10 +603,8 @@ fn override_endpoint(interface: &str, unset: bool) -> Result<(), Error> {
}
fn show(short: bool, tree: bool, interface: Option<Interface>) -> Result<(), Error> {
let interfaces = interface.map_or_else(
|| DeviceInfo::enumerate(),
|interface| Ok(vec![interface.to_string()]),
)?;
let interfaces =
interface.map_or_else(|| DeviceInfo::enumerate(), |interface| Ok(vec![*interface]))?;
let devices = interfaces.into_iter().filter_map(|name| {
DataStore::open(&name)
@ -678,7 +682,7 @@ fn print_interface(device_info: &DeviceInfo, me: &Peer, short: bool) -> Result<(
.to_base64();
if short {
println!("{}", device_info.name.green().bold());
println!("{}", device_info.name.to_string().green().bold());
println!(
" {} {}: {} ({}...)",
"(you)".bold(),
@ -690,7 +694,7 @@ fn print_interface(device_info: &DeviceInfo, me: &Peer, short: bool) -> Result<(
println!(
"{}: {} ({}...)",
"interface".green().bold(),
device_info.name.green(),
device_info.name.to_string().green(),
public_key[..10].yellow()
);
if !short {

View File

@ -1,6 +1,6 @@
use crossbeam::channel::{self, select};
use dashmap::DashMap;
use wgctrl::DeviceInfo;
use wgctrl::{DeviceInfo, InterfaceName};
use std::{io, net::SocketAddr, sync::Arc, thread, time::Duration};
@ -18,7 +18,7 @@ impl std::ops::Deref for Endpoints {
}
impl Endpoints {
pub fn new(iface: &str) -> Result<Self, io::Error> {
pub fn new(iface: &InterfaceName) -> Result<Self, io::Error> {
let endpoints = Arc::new(DashMap::new());
let (stop_tx, stop_rx) = channel::bounded(1);

View File

@ -100,6 +100,9 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> {
(name, root_cidr)
});
// This probably won't error because of the `hostname_validator` regex.
let name = name.parse()?;
let endpoint: SocketAddr = conf.endpoint.unwrap_or_else(|| {
prompts::ask_endpoint()
.map_err(|_| println!("failed to get endpoint."))
@ -131,7 +134,7 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> {
config.write_to_path(&config_path)?;
let db_init_data = DbInitData {
root_cidr_name: name.clone(),
root_cidr_name: name.to_string(),
root_cidr,
server_cidr,
our_ip,
@ -176,7 +179,7 @@ pub fn init_wizard(conf: &ServerConfig) -> Result<(), Error> {
",
star = "[*]".dimmed(),
interface = name.yellow(),
interface = name.to_string().yellow(),
created = "created".green(),
wg_manage_server = "innernet-server".yellow(),
add_cidr = "add-cidr".yellow(),

View File

@ -19,7 +19,7 @@ use std::{
};
use structopt::StructOpt;
use warp::Filter;
use wgctrl::{DeviceConfigBuilder, DeviceInfo, PeerConfigBuilder};
use wgctrl::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfigBuilder};
pub mod api;
pub mod db;
@ -67,7 +67,7 @@ pub type Db = Arc<Mutex<Connection>>;
pub struct Context {
pub db: Db,
pub endpoints: Arc<Endpoints>,
pub interface: String,
pub interface: InterfaceName,
}
pub struct Session {
@ -140,10 +140,10 @@ impl ServerConfig {
.unwrap_or(*SERVER_DATABASE_DIR)
}
fn database_path(&self, interface: &str) -> PathBuf {
fn database_path(&self, interface: &InterfaceName) -> PathBuf {
PathBuf::new()
.join(self.database_dir())
.join(interface)
.join(interface.to_string())
.with_extension("db")
}
@ -153,10 +153,10 @@ impl ServerConfig {
.unwrap_or(*SERVER_CONFIG_DIR)
}
fn config_path(&self, interface: &str) -> PathBuf {
fn config_path(&self, interface: &InterfaceName) -> PathBuf {
PathBuf::new()
.join(self.config_dir())
.join(interface)
.join(interface.to_string())
.with_extension("conf")
}
}
@ -192,7 +192,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
fn open_database_connection(
interface: &str,
interface: &InterfaceName,
conf: &ServerConfig,
) -> Result<rusqlite::Connection, Box<dyn std::error::Error>> {
let database_path = conf.database_path(&interface);
@ -207,8 +207,8 @@ fn open_database_connection(
Ok(Connection::open(&database_path)?)
}
fn add_peer(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(&interface))?;
fn add_peer(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(interface))?;
let conn = open_database_connection(interface, conf)?;
let peers = DatabasePeer::list(&conn)?
.into_iter()
@ -245,7 +245,7 @@ fn add_peer(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
Ok(())
}
fn add_cidr(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
fn add_cidr(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> {
let conn = open_database_connection(interface, conf)?;
let cidrs = DatabaseCidr::list(&conn)?;
if let Some(cidr_request) = shared::prompts::add_cidr(&cidrs)? {
@ -268,9 +268,9 @@ fn add_cidr(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
Ok(())
}
async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(&interface))?;
let conn = open_database_connection(&interface, conf)?;
async fn serve(interface: &InterfaceName, conf: &ServerConfig) -> Result<(), Error> {
let config = ConfigFile::from_file(conf.config_path(interface))?;
let conn = open_database_connection(interface, conf)?;
// Foreign key constraints aren't on in SQLite by default. Enable.
conn.pragma_update(None, "foreign_keys", &1)?;
@ -282,7 +282,7 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
log::info!("bringing up interface.");
wg::up(
&interface,
interface,
&config.private_key,
IpNetwork::new(config.address, config.network_cidr_prefix)?,
Some(config.listen_port),
@ -300,7 +300,7 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
let db = Arc::new(Mutex::new(conn));
let context = Context {
db,
interface: interface.to_string(),
interface: *interface,
endpoints,
};
@ -334,11 +334,11 @@ async fn serve(interface: &str, conf: &ServerConfig) -> Result<(), Error> {
///
/// See https://github.com/tonarino/innernet/issues/26 for more details.
#[cfg(target_os = "linux")]
fn get_listener(addr: SocketAddr, interface: &str) -> Result<TcpListener, Error> {
fn get_listener(addr: SocketAddr, interface: &InterfaceName) -> Result<TcpListener, Error> {
let listener = TcpListener::bind(&addr)?;
listener.set_nonblocking(true)?;
let sock = socket2::Socket::from(listener);
sock.bind_device(Some(interface.as_bytes()))?;
sock.bind_device(Some(interface.as_str_lossy().as_bytes()))?;
Ok(sock.into())
}
@ -349,7 +349,7 @@ fn get_listener(addr: SocketAddr, interface: &str) -> Result<TcpListener, Error>
///
/// See https://github.com/tonarino/innernet/issues/26 for more details.
#[cfg(not(target_os = "linux"))]
fn get_listener(addr: SocketAddr, _interface: &str) -> Result<TcpListener, Error> {
fn get_listener(addr: SocketAddr, _interface: &InterfaceName) -> Result<TcpListener, Error> {
let listener = TcpListener::bind(&addr)?;
listener.set_nonblocking(true)?;
Ok(listener)

View File

@ -12,7 +12,7 @@ use shared::{Cidr, CidrContents, PeerContents};
use std::{net::SocketAddr, path::PathBuf, sync::Arc};
use tempfile::TempDir;
use warp::test::RequestBuilder;
use wgctrl::KeyPair;
use wgctrl::{InterfaceName, KeyPair};
pub const ROOT_CIDR: &str = "10.80.0.0/15";
pub const SERVER_CIDR: &str = "10.80.0.1/32";
@ -45,7 +45,7 @@ pub const USER2_PEER_ID: i64 = 6;
pub struct Server {
pub db: Arc<Mutex<Connection>>,
endpoints: Arc<Endpoints>,
interface: String,
interface: InterfaceName,
conf: ServerConfig,
// The directory will be removed during destruction.
_test_dir: TempDir,
@ -69,6 +69,7 @@ impl Server {
};
init_wizard(&conf).map_err(|_| anyhow!("init_wizard failed"))?;
let interface = interface.parse().unwrap();
// Add developer CIDR and user CIDR and some peers for testing.
let db = Connection::open(&conf.database_path(&interface))?;
db.pragma_update(None, "foreign_keys", &1)?;

View File

@ -9,6 +9,7 @@ use std::{
os::unix::fs::PermissionsExt,
path::{Path, PathBuf},
};
use wgctrl::InterfaceName;
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "kebab-case")]
@ -92,7 +93,7 @@ impl InterfaceConfig {
}
/// Overwrites the config file if it already exists.
pub fn write_to_interface(&self, interface: &str) -> Result<PathBuf, Error> {
pub fn write_to_interface(&self, interface: &InterfaceName) -> Result<PathBuf, Error> {
let path = Self::build_config_file_path(interface)?;
File::create(&path)
.with_path(&path)?
@ -104,13 +105,15 @@ impl InterfaceConfig {
Ok(toml::from_slice(&std::fs::read(&path).with_path(path)?)?)
}
pub fn from_interface(interface: &str) -> Result<Self, Error> {
pub fn from_interface(interface: &InterfaceName) -> Result<Self, Error> {
Self::from_file(Self::build_config_file_path(interface)?)
}
fn build_config_file_path(interface: &str) -> Result<PathBuf, Error> {
fn build_config_file_path(interface: &InterfaceName) -> Result<PathBuf, Error> {
ensure_dirs_exist(&[*CLIENT_CONFIG_PATH])?;
Ok(CLIENT_CONFIG_PATH.join(interface).with_extension("conf"))
Ok(CLIENT_CONFIG_PATH
.join(interface.to_string())
.with_extension("conf"))
}
}

View File

@ -13,7 +13,7 @@ use std::{
str::FromStr,
time::Duration,
};
use wgctrl::{Key, PeerConfig, PeerConfigBuilder};
use wgctrl::{InterfaceName, InvalidInterfaceName, Key, PeerConfig, PeerConfigBuilder};
pub mod interface_config;
pub mod prompts;
@ -65,23 +65,24 @@ impl std::error::Error for WrappedIoError {}
#[derive(Debug, Clone)]
pub struct Interface {
name: String,
name: InterfaceName,
}
impl FromStr for Interface {
type Err = &'static str;
type Err = String;
fn from_str(name: &str) -> Result<Self, Self::Err> {
let s = name.to_string();
hostname_validator(&s)?;
Ok(Self {
name: name.to_string(),
})
let name = name.to_string();
hostname_validator(&name)?;
let name = name
.parse()
.map_err(|e: InvalidInterfaceName| e.to_string())?;
Ok(Self { name })
}
}
impl Deref for Interface {
type Target = str;
type Target = InterfaceName;
fn deref(&self) -> &Self::Target {
&self.name

View File

@ -9,7 +9,7 @@ use ipnetwork::IpNetwork;
use lazy_static::lazy_static;
use regex::Regex;
use std::net::{IpAddr, SocketAddr};
use wgctrl::KeyPair;
use wgctrl::{InterfaceName, KeyPair};
lazy_static! {
static ref THEME: ColorfulTheme = ColorfulTheme::default();
@ -239,7 +239,7 @@ pub fn enable_or_disable_peer(peers: &[Peer], enable: bool) -> Result<Option<Pee
/// Confirm and write a innernet invitation file after a peer has been created.
pub fn save_peer_invitation(
network_name: &str,
network_name: &InterfaceName,
peer: &Peer,
server_peer: &Peer,
root_cidr: &Cidr,

View File

@ -4,7 +4,7 @@ use std::{
net::{IpAddr, SocketAddr},
process::{self, Command},
};
use wgctrl::{DeviceConfigBuilder, PeerConfigBuilder};
use wgctrl::{DeviceConfigBuilder, InterfaceName, PeerConfigBuilder};
fn cmd(bin: &str, args: &[&str]) -> Result<process::Output, Error> {
let output = Command::new(bin).args(args).output()?;
@ -22,8 +22,9 @@ fn cmd(bin: &str, args: &[&str]) -> Result<process::Output, Error> {
}
#[cfg(target_os = "macos")]
pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> {
let real_interface = wgctrl::backends::userspace::resolve_tun(interface).with_str(interface)?;
pub fn set_addr(interface: &InterfaceName, addr: IpNetwork) -> Result<(), Error> {
let real_interface =
wgctrl::backends::userspace::resolve_tun(interface).with_str(interface.to_string())?;
if addr.is_ipv4() {
cmd(
@ -47,20 +48,21 @@ pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> {
}
#[cfg(target_os = "linux")]
pub fn set_addr(interface: &str, addr: IpNetwork) -> Result<(), Error> {
pub fn set_addr(interface: &InterfaceName, addr: IpNetwork) -> Result<(), Error> {
let interface = interface.to_string();
cmd(
"ip",
&["address", "replace", &addr.to_string(), "dev", interface],
&["address", "replace", &addr.to_string(), "dev", &interface],
)?;
let _ = cmd(
"ip",
&["link", "set", "mtu", "1420", "up", "dev", interface],
&["link", "set", "mtu", "1420", "up", "dev", &interface],
);
Ok(())
}
pub fn up(
interface: &str,
interface: &InterfaceName,
private_key: &str,
address: IpNetwork,
listen_port: Option<u16>,
@ -85,7 +87,7 @@ pub fn up(
Ok(())
}
pub fn set_listen_port(interface: &str, listen_port: Option<u16>) -> Result<(), Error> {
pub fn set_listen_port(interface: &InterfaceName, listen_port: Option<u16>) -> Result<(), Error> {
let mut device = DeviceConfigBuilder::new();
if let Some(listen_port) = listen_port {
device = device.set_listen_port(listen_port);
@ -98,24 +100,24 @@ pub fn set_listen_port(interface: &str, listen_port: Option<u16>) -> Result<(),
}
#[cfg(target_os = "linux")]
pub fn down(interface: &str) -> Result<(), Error> {
Ok(wgctrl::delete_interface(interface).with_str(interface)?)
pub fn down(interface: &InterfaceName) -> Result<(), Error> {
Ok(wgctrl::delete_interface(&interface).with_str(interface.to_string())?)
}
#[cfg(not(target_os = "linux"))]
pub fn down(interface: &str) -> Result<(), Error> {
pub fn down(interface: &InterfaceName) -> Result<(), Error> {
wgctrl::backends::userspace::delete_interface(interface)
.with_str(interface)
.with_str(interface.to_string())
.map_err(Error::from)
}
/// Add a route in the OS's routing table to get traffic flowing through this interface.
/// Returns an error if the process doesn't exit successfully, otherwise returns
/// true if the route was changed, false if the route already exists.
pub fn add_route(interface: &str, cidr: IpNetwork) -> Result<bool, Error> {
pub fn add_route(interface: &InterfaceName, cidr: IpNetwork) -> Result<bool, Error> {
if cfg!(target_os = "macos") {
let real_interface =
wgctrl::backends::userspace::resolve_tun(interface).with_str(interface)?;
wgctrl::backends::userspace::resolve_tun(interface).with_str(interface.to_string())?;
let output = cmd(
"route",
&[
@ -141,7 +143,13 @@ pub fn add_route(interface: &str, cidr: IpNetwork) -> Result<bool, Error> {
// TODO(mcginty): use the netlink interface on linux to modify routing table.
let _ = cmd(
"ip",
&["route", "add", &cidr.to_string(), "dev", &interface],
&[
"route",
"add",
&cidr.to_string(),
"dev",
&interface.to_string(),
],
);
Ok(false)
}

View File

@ -1,6 +1,6 @@
use crate::{
device::AllowedIp, DeviceConfigBuilder, DeviceInfo, InvalidKey, PeerConfig, PeerConfigBuilder,
PeerInfo, PeerStats,
device::AllowedIp, DeviceConfigBuilder, DeviceInfo, InterfaceName, InvalidInterfaceName,
InvalidKey, PeerConfig, PeerConfigBuilder, PeerInfo, PeerStats,
};
use wgctrl_sys::{timespec64, wg_device_flags as wgdf, wg_peer_flags as wgpf};
@ -71,8 +71,10 @@ impl<'a> From<&'a wgctrl_sys::wg_peer> for PeerInfo {
impl<'a> From<&'a wgctrl_sys::wg_device> for DeviceInfo {
fn from(raw: &wgctrl_sys::wg_device) -> DeviceInfo {
// SAFETY: The name string buffer came directly from wgctrl so its NUL terminated.
let name = unsafe { InterfaceName::from_wg(raw.name) };
DeviceInfo {
name: parse_device_name(raw.name),
name,
public_key: if (raw.flags & wgdf::WGDEVICE_HAS_PUBLIC_KEY).0 > 0 {
Some(Key::from_raw(raw.public_key))
} else {
@ -98,15 +100,6 @@ impl<'a> From<&'a wgctrl_sys::wg_device> for DeviceInfo {
}
}
fn parse_device_name(name: [c_char; 16]) -> String {
let name: &[u8; 16] = unsafe { &*((&name) as *const _ as *const [u8; 16]) };
let idx: usize = name
.iter()
.position(|x| *x == 0)
.expect("Interface name too long?");
unsafe { str::from_utf8_unchecked(&name[..idx]) }.to_owned()
}
fn parse_peers(dev: &wgctrl_sys::wg_device) -> Vec<PeerInfo> {
let mut result = Vec::new();
@ -297,15 +290,6 @@ fn encode_peers(
(first_peer, last_peer)
}
fn encode_name(name: &str) -> [c_char; 16] {
let slice = unsafe { &*(name.as_bytes() as *const _ as *const [c_char]) };
let mut result = [c_char::default(); 16];
result[..slice.len()].copy_from_slice(slice);
result
}
pub fn exists() -> bool {
// Try to load the wireguard module if it isn't already.
// This is only called once per lifetime of the process.
@ -320,7 +304,7 @@ pub fn exists() -> bool {
Path::new("/sys/module/wireguard").is_dir()
}
pub fn enumerate() -> Result<Vec<String>, io::Error> {
pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
let base = unsafe { wgctrl_sys::wg_list_device_names() };
if base.is_null() {
@ -340,7 +324,12 @@ pub fn enumerate() -> Result<Vec<String>, io::Error> {
}
current = unsafe { current.add(len + 1) };
result.push(unsafe { str::from_utf8_unchecked(next_dev) }.to_owned());
let interface: InterfaceName = str::from_utf8(next_dev)
.map_err(|_| InvalidInterfaceName::InvalidChars)?
.parse()?;
result.push(interface);
}
unsafe { libc::free(base as *mut libc::c_void) };
@ -348,18 +337,17 @@ pub fn enumerate() -> Result<Vec<String>, io::Error> {
Ok(result)
}
pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> {
pub fn apply(builder: DeviceConfigBuilder, iface: &InterfaceName) -> io::Result<()> {
let (first_peer, last_peer) = encode_peers(builder.peers);
let iface_str = CString::new(iface)?;
let result = unsafe { wgctrl_sys::wg_add_device(iface_str.as_ptr()) };
let result = unsafe { wgctrl_sys::wg_add_device(iface.as_ptr()) };
match result {
0 | -17 => {},
_ => return Err(io::Error::last_os_error()),
};
let mut wg_device = Box::new(wgctrl_sys::wg_device {
name: encode_name(iface),
name: iface.into_inner(),
ifindex: 0,
public_key: wgctrl_sys::wg_key::default(),
private_key: wgctrl_sys::wg_key::default(),
@ -406,15 +394,13 @@ pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> {
}
}
pub fn get_by_name(name: &str) -> Result<DeviceInfo, io::Error> {
pub fn get_by_name(name: &InterfaceName) -> Result<DeviceInfo, io::Error> {
let mut device: *mut wgctrl_sys::wg_device = ptr::null_mut();
let cs = CString::new(name)?;
let result = unsafe {
wgctrl_sys::wg_get_device(
(&mut device) as *mut _ as *mut *mut wgctrl_sys::wg_device,
cs.as_ptr(),
name.as_ptr(),
)
};
@ -429,9 +415,8 @@ pub fn get_by_name(name: &str) -> Result<DeviceInfo, io::Error> {
result
}
pub fn delete_interface(iface: &str) -> io::Result<()> {
let iface_str = CString::new(iface)?;
let result = unsafe { wgctrl_sys::wg_del_device(iface_str.as_ptr()) };
pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
let result = unsafe { wgctrl_sys::wg_del_device(iface.as_ptr()) };
if result == 0 {
Ok(())

View File

@ -1,11 +1,11 @@
use crate::{DeviceConfigBuilder, DeviceInfo, PeerConfig, PeerInfo, PeerStats};
use crate::{DeviceConfigBuilder, DeviceInfo, InterfaceName, PeerConfig, PeerInfo, PeerStats};
#[cfg(target_os = "linux")]
use crate::Key;
use std::{
fs, io,
io::{prelude::*, BufReader},
fs,
io::{self, prelude::*, BufReader},
os::unix::net::UnixStream,
path::{Path, PathBuf},
process::Command,
@ -28,31 +28,31 @@ fn get_base_folder() -> io::Result<PathBuf> {
}
}
fn get_namefile(name: &str) -> io::Result<PathBuf> {
Ok(get_base_folder()?.join(&format!("{}.name", name)))
fn get_namefile(name: &InterfaceName) -> io::Result<PathBuf> {
Ok(get_base_folder()?.join(&format!("{}.name", name.as_str_lossy())))
}
fn get_socketfile(name: &str) -> io::Result<PathBuf> {
fn get_socketfile(name: &InterfaceName) -> io::Result<PathBuf> {
Ok(get_base_folder()?.join(&format!("{}.sock", resolve_tun(name)?)))
}
fn open_socket(name: &str) -> io::Result<UnixStream> {
fn open_socket(name: &InterfaceName) -> io::Result<UnixStream> {
UnixStream::connect(get_socketfile(name)?)
}
pub fn resolve_tun(name: &str) -> io::Result<String> {
pub fn resolve_tun(name: &InterfaceName) -> io::Result<String> {
let namefile = get_namefile(name)?;
Ok(fs::read_to_string(namefile)?.trim().to_string())
}
pub fn delete_interface(name: &str) -> io::Result<()> {
pub fn delete_interface(name: &InterfaceName) -> io::Result<()> {
fs::remove_file(get_socketfile(name)?).ok();
fs::remove_file(get_namefile(name)?).ok();
Ok(())
}
pub fn enumerate() -> Result<Vec<String>, io::Error> {
pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
use std::ffi::OsStr;
let mut interfaces = vec![];
@ -61,7 +61,7 @@ pub fn enumerate() -> Result<Vec<String>, io::Error> {
if path.extension() == Some(OsStr::new("name")) {
let stem = path.file_stem().map(|stem| stem.to_str()).flatten();
if let Some(name) = stem {
interfaces.push(name.to_string());
interfaces.push(name.parse()?);
}
}
}
@ -100,9 +100,10 @@ impl From<ConfigParser> for DeviceInfo {
}
impl ConfigParser {
fn new(name: &str) -> Self {
/// Returns `None` if an invalid device name was provided.
fn new(name: &InterfaceName) -> Self {
let device_info = DeviceInfo {
name: name.to_string(),
name: *name,
public_key: None,
private_key: None,
fwmark: None,
@ -228,13 +229,14 @@ impl ConfigParser {
}
}
pub fn get_by_name(name: &str) -> Result<DeviceInfo, io::Error> {
pub fn get_by_name(name: &InterfaceName) -> Result<DeviceInfo, io::Error> {
let mut sock = open_socket(name)?;
sock.write_all(b"get=1\n\n")?;
let mut reader = BufReader::new(sock);
let mut buf = String::new();
let mut parser = ConfigParser::new(name);
loop {
match reader.read_line(&mut buf)? {
0 | 1 if buf == "\n" => break,
@ -261,7 +263,7 @@ fn get_userspace_implementation() -> String {
.unwrap_or_else(|_| "wireguard-go".to_string())
}
pub fn apply(builder: DeviceConfigBuilder, iface: &str) -> io::Result<()> {
pub fn apply(builder: DeviceConfigBuilder, iface: &InterfaceName) -> io::Result<()> {
// If we can't open a configuration socket to an existing interface, try starting it.
let mut sock = match open_socket(iface) {
Err(_) => {

View File

@ -1,6 +1,6 @@
use crate::{
backends,
device::{AllowedIp, PeerConfig},
device::{AllowedIp, InterfaceName, PeerConfig},
key::{Key, KeyPair},
};
@ -35,7 +35,7 @@ use std::{
/// peer.set_endpoint(server_addr)
/// .replace_allowed_ips()
/// .allow_all_ips()
/// }).apply("wg-example");
/// }).apply(&"wg-example".parse().unwrap());
///
/// println!("Send these keys to your peer: {:#?}", peer_keypair);
///
@ -171,16 +171,16 @@ impl DeviceConfigBuilder {
///
/// An interface with the provided name will be created if one does not exist already.
#[cfg(target_os = "linux")]
pub fn apply(self, iface: &str) -> io::Result<()> {
pub fn apply(self, iface: &InterfaceName) -> io::Result<()> {
if backends::kernel::exists() {
backends::kernel::apply(self, iface)
backends::kernel::apply(self, &iface)
} else {
backends::userspace::apply(self, iface)
}
}
#[cfg(not(target_os = "linux"))]
pub fn apply(self, iface: &str) -> io::Result<()> {
pub fn apply(self, iface: &InterfaceName) -> io::Result<()> {
backends::userspace::apply(self, iface)
}
}
@ -215,7 +215,7 @@ impl Default for DeviceConfigBuilder {
/// .add_allowed_ip("192.168.1.2".parse()?, 32);
///
/// // update our existing configuration with the new peer
/// DeviceConfigBuilder::new().add_peer(peer).apply("wg-example");
/// DeviceConfigBuilder::new().add_peer(peer).apply(&"wg-example".parse().unwrap());
///
/// println!("Send these keys to your peer: {:#?}", peer_keypair);
///
@ -353,6 +353,6 @@ impl PeerConfigBuilder {
/// Deletes an existing WireGuard interface by name.
#[cfg(target_os = "linux")]
pub fn delete_interface(iface: &str) -> io::Result<()> {
pub fn delete_interface(iface: &InterfaceName) -> io::Result<()> {
backends::kernel::delete_interface(iface)
}

View File

@ -1,7 +1,13 @@
use libc::c_char;
use crate::{backends, key::Key};
use std::{
borrow::Cow,
ffi::CStr,
fmt,
net::{IpAddr, SocketAddr},
str::FromStr,
time::SystemTime,
};
@ -86,7 +92,7 @@ pub struct PeerInfo {
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DeviceInfo {
/// The interface name of this device
pub name: String,
pub name: InterfaceName,
/// The public encryption key of this interface (if present)
pub public_key: Option<Key>,
/// The private encryption key of this interface (if present)
@ -103,6 +109,132 @@ pub struct DeviceInfo {
pub(crate) __cant_construct_me: (),
}
type RawInterfaceName = [c_char; libc::IFNAMSIZ];
/// The name of a Wireguard interface device.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct InterfaceName(RawInterfaceName);
impl FromStr for InterfaceName {
type Err = InvalidInterfaceName;
/// Attempts to parse a Rust string as a valid Linux interface name.
///
/// Extra validation logic ported from [iproute2](https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/utils.c#n827)
fn from_str(name: &str) -> Result<Self, InvalidInterfaceName> {
let len = name.len();
// Ensure its short enough to include a trailing NUL
if len > (libc::IFNAMSIZ - 1) {
return Err(InvalidInterfaceName::TooLong(len));
}
if len == 0 || name.trim_start_matches('\0').is_empty() {
return Err(InvalidInterfaceName::Empty);
}
let mut buf = [c_char::default(); libc::IFNAMSIZ];
// Check for interior NULs and other invalid characters.
for (out, b) in buf.iter_mut().zip(name.as_bytes()[..(len - 1)].iter()) {
if *b == 0 {
return Err(InvalidInterfaceName::InteriorNul);
}
if *b == b'/' || b.is_ascii_whitespace() {
return Err(InvalidInterfaceName::InvalidChars);
}
*out = *b as i8;
}
Ok(Self(buf))
}
}
impl InterfaceName {
#[cfg(target_os = "linux")]
/// Creates a new [InterfaceName](Self).
///
/// ## Safety
///
/// The caller must ensure that `name` is a valid C string terminated by a NUL.
pub(crate) unsafe fn from_wg(name: RawInterfaceName) -> Self {
Self(name)
}
/// Returns a human-readable form of the device name.
///
/// Only use this when the interface name was constructed from a Rust string.
pub fn as_str_lossy(&self) -> Cow<'_, str> {
// SAFETY: These are C strings coming from wgctrl, so they are correctly NUL terminated.
unsafe { CStr::from_ptr(self.0.as_ptr()) }.to_string_lossy()
}
#[cfg(target_os = "linux")]
/// Returns a pointer to the inner byte buffer for FFI calls.
pub(crate) fn as_ptr(&self) -> *const c_char {
self.0.as_ptr()
}
#[cfg(target_os = "linux")]
/// Consumes this interface name, returning its raw byte buffer.
pub(crate) fn into_inner(self) -> RawInterfaceName {
self.0
}
}
impl fmt::Debug for InterfaceName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.as_str_lossy())
}
}
impl fmt::Display for InterfaceName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.as_str_lossy())
}
}
/// An interface name was bad.
#[derive(Debug, PartialEq)]
pub enum InvalidInterfaceName {
/// Provided name had an interior NUL byte.
InteriorNul,
/// Provided name was longer then the interface name length limit
/// of the system.
TooLong(usize),
// These checks are done in the kernel as well, but no reason to let bad names
// get that far: https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/utils.c?id=1f420318bda3cc62156e89e1b56d60cc744b48ad#n827.
/// Interface name was an empty string.
Empty,
/// Interface name contained a `/` or space character.
InvalidChars,
}
impl fmt::Display for InvalidInterfaceName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InteriorNul => f.write_str("interface name contained an interior NUL byte"),
Self::TooLong(size) => write!(
f,
"interface name was {} bytes long but the system's max is {}",
size,
libc::IFNAMSIZ
),
Self::Empty => f.write_str("an empty interface name was provided"),
Self::InvalidChars => f.write_str("interface name contained slash or space characters"),
}
}
}
impl From<InvalidInterfaceName> for std::io::Error {
fn from(e: InvalidInterfaceName) -> Self {
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
}
}
impl std::error::Error for InvalidInterfaceName {}
impl DeviceInfo {
/// Enumerates all WireGuard interfaces currently present in the system
/// and returns their names.
@ -110,7 +242,7 @@ impl DeviceInfo {
/// You can use [`get_by_name`](DeviceInfo::get_by_name) to retrieve more
/// detailed information on each interface.
#[cfg(target_os = "linux")]
pub fn enumerate() -> Result<Vec<String>, std::io::Error> {
pub fn enumerate() -> Result<Vec<InterfaceName>, std::io::Error> {
if backends::kernel::exists() {
backends::kernel::enumerate()
} else {
@ -119,12 +251,12 @@ impl DeviceInfo {
}
#[cfg(not(target_os = "linux"))]
pub fn enumerate() -> Result<Vec<String>, std::io::Error> {
pub fn enumerate() -> Result<Vec<InterfaceName>, std::io::Error> {
crate::backends::userspace::enumerate()
}
#[cfg(target_os = "linux")]
pub fn get_by_name(name: &str) -> Result<Self, std::io::Error> {
pub fn get_by_name(name: &InterfaceName) -> Result<Self, std::io::Error> {
if backends::kernel::exists() {
backends::kernel::get_by_name(name)
} else {
@ -133,7 +265,7 @@ impl DeviceInfo {
}
#[cfg(not(target_os = "linux"))]
pub fn get_by_name(name: &str) -> Result<Self, std::io::Error> {
pub fn get_by_name(name: &InterfaceName) -> Result<Self, std::io::Error> {
backends::userspace::get_by_name(name)
}
@ -150,7 +282,9 @@ impl DeviceInfo {
#[cfg(test)]
mod tests {
use crate::{DeviceConfigBuilder, KeyPair, PeerConfigBuilder};
use crate::{
DeviceConfigBuilder, InterfaceName, InvalidInterfaceName, KeyPair, PeerConfigBuilder,
};
const TEST_INTERFACE: &str = "wgctrl-test";
use super::*;
@ -166,9 +300,10 @@ mod tests {
for keypair in &keypairs {
builder = builder.add_peer(PeerConfigBuilder::new(&keypair.public))
}
builder.apply(TEST_INTERFACE).unwrap();
let interface = TEST_INTERFACE.parse().unwrap();
builder.apply(&interface).unwrap();
let device = DeviceInfo::get_by_name(TEST_INTERFACE).unwrap();
let device = DeviceInfo::get_by_name(&interface).unwrap();
for keypair in &keypairs {
assert!(device
@ -179,4 +314,24 @@ mod tests {
device.delete().unwrap();
}
#[test]
fn test_interface_names() {
assert!("wg-01".parse::<InterfaceName>().is_ok());
assert!("longer-nul\0".parse::<InterfaceName>().is_ok());
let invalid_names = &[
("", InvalidInterfaceName::Empty), // Empty Rust string
("\0", InvalidInterfaceName::Empty), // Empty C string
("ifname\0nul", InvalidInterfaceName::InteriorNul), // Contains interior NUL
("if name", InvalidInterfaceName::InvalidChars), // Contains a space
("ifna/me", InvalidInterfaceName::InvalidChars), // Contains a slash
("if na/me", InvalidInterfaceName::InvalidChars), // Contains a space and slash
("interfacelongname", InvalidInterfaceName::TooLong(17)), // Too long
];
for (name, expected) in invalid_names {
assert!(name.parse::<InterfaceName>().as_ref() == Err(expected))
}
}
}