From 05d78eb253082e44e56d2db2d290f4195b75e154 Mon Sep 17 00:00:00 2001 From: Jake McGinty Date: Sun, 11 Apr 2021 14:56:47 +0900 Subject: [PATCH] shared: add types module --- shared/src/lib.rs | 394 +------------------------------------------ shared/src/types.rs | 396 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 399 insertions(+), 391 deletions(-) create mode 100644 shared/src/types.rs diff --git a/shared/src/lib.rs b/shared/src/lib.rs index 4e5526b..026b1c3 100644 --- a/shared/src/lib.rs +++ b/shared/src/lib.rs @@ -1,25 +1,20 @@ use colored::*; -use ipnetwork::IpNetwork; use lazy_static::lazy_static; -use prompts::hostname_validator; -use serde::{Deserialize, Serialize}; use std::{ - fmt::{Display, Formatter}, fs::{self, File}, io, - net::{IpAddr, SocketAddr}, - ops::Deref, os::unix::fs::PermissionsExt, path::Path, - str::FromStr, time::Duration, }; -use wgctrl::{InterfaceName, InvalidInterfaceName, Key, PeerConfig, PeerConfigBuilder}; pub mod interface_config; pub mod prompts; +pub mod types; pub mod wg; +pub use types::*; + lazy_static! { pub static ref CLIENT_CONFIG_PATH: &'static Path = Path::new("/etc/innernet"); pub static ref CLIENT_DATA_PATH: &'static Path = Path::new("/var/lib/innernet"); @@ -33,330 +28,6 @@ pub const INNERNET_PUBKEY_HEADER: &str = "X-Innernet-Server-Key"; pub type Error = Box; -pub trait IoErrorContext { - fn with_path>(self, path: P) -> Result; - fn with_str>(self, context: S) -> Result; -} - -impl IoErrorContext for Result { - fn with_path>(self, path: P) -> Result { - self.with_str(path.as_ref().to_string_lossy()) - } - - fn with_str>(self, context: S) -> Result { - self.map_err(|e| WrappedIoError { - io_error: e, - context: context.into(), - }) - } -} - -#[derive(Debug)] -pub struct WrappedIoError { - io_error: std::io::Error, - context: String, -} - -impl std::fmt::Display for WrappedIoError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - write!(f, "{} - {}", self.context, self.io_error) - } -} - -impl std::error::Error for WrappedIoError {} - -#[derive(Debug, Clone)] -pub struct Interface { - name: InterfaceName, -} - -impl FromStr for Interface { - type Err = String; - - fn from_str(name: &str) -> Result { - let name = name.to_string(); - hostname_validator(&name)?; - let name = name - .parse() - .map_err(|e: InvalidInterfaceName| e.to_string())?; - Ok(Self { name }) - } -} - -impl Deref for Interface { - type Target = InterfaceName; - - fn deref(&self) -> &Self::Target { - &self.name - } -} - -#[derive(Deserialize, Serialize, Debug)] -#[serde(tag = "option", content = "content")] -pub enum EndpointContents { - Set(SocketAddr), - Unset, -} - -impl Into> for EndpointContents { - fn into(self) -> Option { - match self { - Self::Set(addr) => Some(addr), - Self::Unset => None, - } - } -} - -impl From> for EndpointContents { - fn from(option: Option) -> Self { - match option { - Some(addr) => Self::Set(addr), - None => Self::Unset, - } - } -} - -#[derive(Deserialize, Serialize, Debug)] -pub struct AssociationContents { - pub cidr_id_1: i64, - pub cidr_id_2: i64, -} - -#[derive(Deserialize, Serialize, Debug)] -pub struct Association { - pub id: i64, - - #[serde(flatten)] - pub contents: AssociationContents, -} - -impl Deref for Association { - type Target = AssociationContents; - - fn deref(&self) -> &Self::Target { - &self.contents - } -} - -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] -pub struct CidrContents { - pub name: String, - pub cidr: IpNetwork, - pub parent: Option, -} - -impl Deref for CidrContents { - type Target = IpNetwork; - - fn deref(&self) -> &Self::Target { - &self.cidr - } -} - -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] -pub struct Cidr { - pub id: i64, - - #[serde(flatten)] - pub contents: CidrContents, -} - -impl Deref for Cidr { - type Target = CidrContents; - - fn deref(&self) -> &Self::Target { - &self.contents - } -} - -pub struct CidrTree<'a> { - cidrs: &'a [Cidr], - contents: &'a Cidr, -} - -impl<'a> std::ops::Deref for CidrTree<'a> { - type Target = Cidr; - - fn deref(&self) -> &Self::Target { - self.contents - } -} - -impl<'a> CidrTree<'a> { - pub fn new(cidrs: &'a [Cidr]) -> Self { - let root = cidrs - .iter() - .min_by_key(|c| c.cidr.prefix()) - .expect("failed to find root CIDR"); - Self { - cidrs, - contents: root, - } - } - - pub fn children(&self) -> impl Iterator { - self.cidrs - .iter() - .filter(move |c| c.parent == Some(self.contents.id)) - .map(move |c| Self { - cidrs: self.cidrs, - contents: c, - }) - } - - pub fn leaves(&self) -> Vec { - let mut leaves = vec![]; - for cidr in self.cidrs { - if !self.cidrs.iter().any(|c| c.parent == Some(cidr.id)) { - leaves.push(cidr.clone()); - } - } - leaves - } -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub struct RedeemContents { - pub public_key: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub struct PeerContents { - pub name: String, - pub ip: IpAddr, - pub cidr_id: i64, - pub public_key: String, - pub endpoint: Option, - pub persistent_keepalive_interval: Option, - pub is_admin: bool, - pub is_disabled: bool, - pub is_redeemed: bool, -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub struct Peer { - pub id: i64, - - #[serde(flatten)] - pub contents: PeerContents, -} - -impl Deref for Peer { - type Target = PeerContents; - - fn deref(&self) -> &Self::Target { - &self.contents - } -} - -impl Display for Peer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} ({})", &self.name, &self.public_key) - } -} - -#[derive(Debug, PartialEq)] -pub struct PeerDiff { - pub public_key: String, - pub endpoint: Option, - pub persistent_keepalive_interval: Option, - pub is_disabled: bool, -} - -impl Peer { - pub fn diff(&self, peer: &PeerConfig) -> Option { - assert_eq!(self.public_key, peer.public_key.to_base64()); - - let endpoint_diff = if peer.endpoint != self.endpoint { - self.endpoint - } else { - None - }; - - let keepalive_diff = - if peer.persistent_keepalive_interval != self.persistent_keepalive_interval { - self.persistent_keepalive_interval - } else { - None - }; - - if endpoint_diff.is_none() && keepalive_diff.is_none() { - None - } else { - Some(PeerDiff { - public_key: self.public_key.clone(), - endpoint: endpoint_diff, - persistent_keepalive_interval: keepalive_diff, - is_disabled: self.is_disabled, - }) - } - } -} - -impl<'a> From<&'a Peer> for PeerConfigBuilder { - fn from(peer: &Peer) -> Self { - let builder = PeerConfigBuilder::new(&Key::from_base64(&peer.public_key).unwrap()) - .replace_allowed_ips() - .add_allowed_ip(peer.ip, if peer.ip.is_ipv4() { 32 } else { 128 }); - - let builder = if peer.is_disabled { - builder.remove() - } else { - builder - }; - - let builder = if let Some(interval) = peer.persistent_keepalive_interval { - builder.set_persistent_keepalive_interval(interval) - } else { - builder - }; - - if let Some(endpoint) = peer.endpoint { - builder.set_endpoint(endpoint) - } else { - builder - } - } -} - -impl<'a> From<&'a PeerDiff> for PeerConfigBuilder { - fn from(peer: &PeerDiff) -> Self { - let builder = PeerConfigBuilder::new(&Key::from_base64(&peer.public_key).unwrap()); - - let builder = if peer.is_disabled { - builder.remove() - } else { - builder - }; - - let builder = if let Some(interval) = peer.persistent_keepalive_interval { - builder.set_persistent_keepalive_interval(interval) - } else { - builder - }; - - if let Some(endpoint) = peer.endpoint { - builder.set_endpoint(endpoint) - } else { - builder - } - } -} - -/// This model is sent as a response to the /state endpoint, and is meant -/// to include all the data a client needs to update its WireGuard interface. -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct State { - /// This list will be only the peers visible to the user requesting this - /// information, not including disabled peers or peers from other CIDRs - /// that the user's CIDR is not authorized to communicate with. - pub peers: Vec, - - /// At the moment, this is all CIDRs, regardless of whether the peer is - /// eligible to communicate with them or not. - pub cidrs: Vec, -} - pub static WG_MANAGE_DIR: &str = "/etc/innernet"; pub static WG_DIR: &str = "/etc/wireguard"; @@ -394,62 +65,3 @@ pub fn chmod(file: &File, new_mode: u32) -> Result { Ok(updated) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_peer_no_diff() { - const PUBKEY: &str = "4CNZorWVtohO64n6AAaH/JyFjIIgBFrfJK2SGtKjzEE="; - let ip: IpAddr = "10.0.0.1".parse().unwrap(); - let peer = Peer { - id: 1, - contents: PeerContents { - name: "peer1".to_owned(), - ip, - cidr_id: 1, - public_key: PUBKEY.to_owned(), - endpoint: None, - persistent_keepalive_interval: None, - is_admin: false, - is_disabled: false, - is_redeemed: true, - }, - }; - let builder = - PeerConfigBuilder::new(&Key::from_base64(PUBKEY).unwrap()).add_allowed_ip(ip, 32); - - let config = builder.into_peer_config(); - - assert_eq!(peer.diff(&config), None); - } - - #[test] - fn test_peer_diff() { - const PUBKEY: &str = "4CNZorWVtohO64n6AAaH/JyFjIIgBFrfJK2SGtKjzEE="; - let ip: IpAddr = "10.0.0.1".parse().unwrap(); - let peer = Peer { - id: 1, - contents: PeerContents { - name: "peer1".to_owned(), - ip, - cidr_id: 1, - public_key: PUBKEY.to_owned(), - endpoint: None, - persistent_keepalive_interval: Some(15), - is_admin: false, - is_disabled: false, - is_redeemed: true, - }, - }; - let builder = - PeerConfigBuilder::new(&Key::from_base64(PUBKEY).unwrap()).add_allowed_ip(ip, 32); - - let config = builder.into_peer_config(); - - println!("{:?}", peer); - println!("{:?}", config); - assert!(matches!(peer.diff(&config), Some(_))); - } -} diff --git a/shared/src/types.rs b/shared/src/types.rs new file mode 100644 index 0000000..72e6526 --- /dev/null +++ b/shared/src/types.rs @@ -0,0 +1,396 @@ +use ipnetwork::IpNetwork; +use crate::prompts::hostname_validator; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Display, Formatter}, + net::{IpAddr, SocketAddr}, + ops::Deref, + path::Path, + str::FromStr, +}; +use wgctrl::{InterfaceName, InvalidInterfaceName, Key, PeerConfig, PeerConfigBuilder}; + +#[derive(Debug, Clone)] +pub struct Interface { + name: InterfaceName, +} + +impl FromStr for Interface { + type Err = String; + + fn from_str(name: &str) -> Result { + let name = name.to_string(); + hostname_validator(&name)?; + let name = name + .parse() + .map_err(|e: InvalidInterfaceName| e.to_string())?; + Ok(Self { name }) + } +} + +impl Deref for Interface { + type Target = InterfaceName; + + fn deref(&self) -> &Self::Target { + &self.name + } +} + +#[derive(Deserialize, Serialize, Debug)] +#[serde(tag = "option", content = "content")] +pub enum EndpointContents { + Set(SocketAddr), + Unset, +} + +impl Into> for EndpointContents { + fn into(self) -> Option { + match self { + Self::Set(addr) => Some(addr), + Self::Unset => None, + } + } +} + +impl From> for EndpointContents { + fn from(option: Option) -> Self { + match option { + Some(addr) => Self::Set(addr), + None => Self::Unset, + } + } +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct AssociationContents { + pub cidr_id_1: i64, + pub cidr_id_2: i64, +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct Association { + pub id: i64, + + #[serde(flatten)] + pub contents: AssociationContents, +} + +impl Deref for Association { + type Target = AssociationContents; + + fn deref(&self) -> &Self::Target { + &self.contents + } +} + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct CidrContents { + pub name: String, + pub cidr: IpNetwork, + pub parent: Option, +} + +impl Deref for CidrContents { + type Target = IpNetwork; + + fn deref(&self) -> &Self::Target { + &self.cidr + } +} + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct Cidr { + pub id: i64, + + #[serde(flatten)] + pub contents: CidrContents, +} + +impl Deref for Cidr { + type Target = CidrContents; + + fn deref(&self) -> &Self::Target { + &self.contents + } +} + +pub struct CidrTree<'a> { + cidrs: &'a [Cidr], + contents: &'a Cidr, +} + +impl<'a> std::ops::Deref for CidrTree<'a> { + type Target = Cidr; + + fn deref(&self) -> &Self::Target { + self.contents + } +} + +impl<'a> CidrTree<'a> { + pub fn new(cidrs: &'a [Cidr]) -> Self { + let root = cidrs + .iter() + .min_by_key(|c| c.cidr.prefix()) + .expect("failed to find root CIDR"); + Self { + cidrs, + contents: root, + } + } + + pub fn children(&self) -> impl Iterator { + self.cidrs + .iter() + .filter(move |c| c.parent == Some(self.contents.id)) + .map(move |c| Self { + cidrs: self.cidrs, + contents: c, + }) + } + + pub fn leaves(&self) -> Vec { + let mut leaves = vec![]; + for cidr in self.cidrs { + if !self.cidrs.iter().any(|c| c.parent == Some(cidr.id)) { + leaves.push(cidr.clone()); + } + } + leaves + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct RedeemContents { + pub public_key: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct PeerContents { + pub name: String, + pub ip: IpAddr, + pub cidr_id: i64, + pub public_key: String, + pub endpoint: Option, + pub persistent_keepalive_interval: Option, + pub is_admin: bool, + pub is_disabled: bool, + pub is_redeemed: bool, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct Peer { + pub id: i64, + + #[serde(flatten)] + pub contents: PeerContents, +} + +impl Deref for Peer { + type Target = PeerContents; + + fn deref(&self) -> &Self::Target { + &self.contents + } +} + +impl Display for Peer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} ({})", &self.name, &self.public_key) + } +} + +#[derive(Debug, PartialEq)] +pub struct PeerDiff { + pub public_key: String, + pub endpoint: Option, + pub persistent_keepalive_interval: Option, + pub is_disabled: bool, +} + +impl Peer { + pub fn diff(&self, peer: &PeerConfig) -> Option { + assert_eq!(self.public_key, peer.public_key.to_base64()); + + let endpoint_diff = if peer.endpoint != self.endpoint { + self.endpoint + } else { + None + }; + + let keepalive_diff = + if peer.persistent_keepalive_interval != self.persistent_keepalive_interval { + self.persistent_keepalive_interval + } else { + None + }; + + if endpoint_diff.is_none() && keepalive_diff.is_none() { + None + } else { + Some(PeerDiff { + public_key: self.public_key.clone(), + endpoint: endpoint_diff, + persistent_keepalive_interval: keepalive_diff, + is_disabled: self.is_disabled, + }) + } + } +} + +impl<'a> From<&'a Peer> for PeerConfigBuilder { + fn from(peer: &Peer) -> Self { + let builder = PeerConfigBuilder::new(&Key::from_base64(&peer.public_key).unwrap()) + .replace_allowed_ips() + .add_allowed_ip(peer.ip, if peer.ip.is_ipv4() { 32 } else { 128 }); + + let builder = if peer.is_disabled { + builder.remove() + } else { + builder + }; + + let builder = if let Some(interval) = peer.persistent_keepalive_interval { + builder.set_persistent_keepalive_interval(interval) + } else { + builder + }; + + if let Some(endpoint) = peer.endpoint { + builder.set_endpoint(endpoint) + } else { + builder + } + } +} + +impl<'a> From<&'a PeerDiff> for PeerConfigBuilder { + fn from(peer: &PeerDiff) -> Self { + let builder = PeerConfigBuilder::new(&Key::from_base64(&peer.public_key).unwrap()); + + let builder = if peer.is_disabled { + builder.remove() + } else { + builder + }; + + let builder = if let Some(interval) = peer.persistent_keepalive_interval { + builder.set_persistent_keepalive_interval(interval) + } else { + builder + }; + + if let Some(endpoint) = peer.endpoint { + builder.set_endpoint(endpoint) + } else { + builder + } + } +} + +/// This model is sent as a response to the /state endpoint, and is meant +/// to include all the data a client needs to update its WireGuard interface. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct State { + /// This list will be only the peers visible to the user requesting this + /// information, not including disabled peers or peers from other CIDRs + /// that the user's CIDR is not authorized to communicate with. + pub peers: Vec, + + /// At the moment, this is all CIDRs, regardless of whether the peer is + /// eligible to communicate with them or not. + pub cidrs: Vec, +} + +pub trait IoErrorContext { + fn with_path>(self, path: P) -> Result; + fn with_str>(self, context: S) -> Result; +} + +impl IoErrorContext for Result { + fn with_path>(self, path: P) -> Result { + self.with_str(path.as_ref().to_string_lossy()) + } + + fn with_str>(self, context: S) -> Result { + self.map_err(|e| WrappedIoError { + io_error: e, + context: context.into(), + }) + } +} + +#[derive(Debug)] +pub struct WrappedIoError { + io_error: std::io::Error, + context: String, +} + +impl std::fmt::Display for WrappedIoError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "{} - {}", self.context, self.io_error) + } +} + +impl std::error::Error for WrappedIoError {} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::IpAddr; + use wgctrl::{Key, PeerConfigBuilder}; + + #[test] + fn test_peer_no_diff() { + const PUBKEY: &str = "4CNZorWVtohO64n6AAaH/JyFjIIgBFrfJK2SGtKjzEE="; + let ip: IpAddr = "10.0.0.1".parse().unwrap(); + let peer = Peer { + id: 1, + contents: PeerContents { + name: "peer1".to_owned(), + ip, + cidr_id: 1, + public_key: PUBKEY.to_owned(), + endpoint: None, + persistent_keepalive_interval: None, + is_admin: false, + is_disabled: false, + is_redeemed: true, + }, + }; + let builder = + PeerConfigBuilder::new(&Key::from_base64(PUBKEY).unwrap()).add_allowed_ip(ip, 32); + + let config = builder.into_peer_config(); + + assert_eq!(peer.diff(&config), None); + } + + #[test] + fn test_peer_diff() { + const PUBKEY: &str = "4CNZorWVtohO64n6AAaH/JyFjIIgBFrfJK2SGtKjzEE="; + let ip: IpAddr = "10.0.0.1".parse().unwrap(); + let peer = Peer { + id: 1, + contents: PeerContents { + name: "peer1".to_owned(), + ip, + cidr_id: 1, + public_key: PUBKEY.to_owned(), + endpoint: None, + persistent_keepalive_interval: Some(15), + is_admin: false, + is_disabled: false, + is_redeemed: true, + }, + }; + let builder = + PeerConfigBuilder::new(&Key::from_base64(PUBKEY).unwrap()).add_allowed_ip(ip, 32); + + let config = builder.into_peer_config(); + + println!("{:?}", peer); + println!("{:?}", config); + assert!(matches!(peer.diff(&config), Some(_))); + } +} \ No newline at end of file