diff --git a/Cargo.lock b/Cargo.lock index 15d3f01..46330a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,7 +358,11 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hostsfile" -version = "1.1.0" +version = "1.2.0" +dependencies = [ + "log", + "tempfile", +] [[package]] name = "http" diff --git a/hostsfile/Cargo.toml b/hostsfile/Cargo.toml index d95b4cc..77526c0 100644 --- a/hostsfile/Cargo.toml +++ b/hostsfile/Cargo.toml @@ -2,9 +2,13 @@ authors = ["Ryo Kawaguchi "] description = "A simplistic /etc/hosts file editor." edition = "2021" -license = "UNLICENSED" +license = "MIT" name = "hostsfile" publish = false -version = "1.1.0" +version = "1.2.0" [dependencies] +log = "0.4" + +[dev-dependencies] +tempfile = "3" diff --git a/hostsfile/src/lib.rs b/hostsfile/src/lib.rs index a2c2688..e19db3b 100644 --- a/hostsfile/src/lib.rs +++ b/hostsfile/src/lib.rs @@ -6,6 +6,7 @@ use std::{ net::IpAddr, path::{Path, PathBuf}, result, + time::{SystemTime, UNIX_EPOCH}, }; pub type Result = result::Result>; @@ -149,13 +150,46 @@ impl HostsBuilder { Ok(hosts_file) } + pub fn get_temp_path(hosts_path: &Path) -> io::Result { + let hosts_dir = hosts_path.parent().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "hosts path missing a parent folder", + ) + })?; + let start = SystemTime::now(); + let since_the_epoch = start + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + let mut temp_filename = hosts_path + .file_name() + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "hosts path missing a filename") + })? + .to_os_string(); + temp_filename.push(format!(".tmp{}", since_the_epoch.as_millis())); + Ok(hosts_dir.with_file_name(temp_filename)) + } + /// Inserts a new section to the specified hosts file. If there is a section with the same tag /// name already, it will be replaced with the new list instead. /// + /// `hosts_path` is the *full* path to write to, including the filename. + /// /// On Windows, the format of one hostname per line will be used, all other systems will use /// the same format as Unix and Unix-like systems (i.e. allow multiple hostnames per line). pub fn write_to>(&self, hosts_path: P) -> io::Result<()> { let hosts_path = hosts_path.as_ref(); + if hosts_path.is_dir() { + // TODO(jake): use io::ErrorKind::IsADirectory when it's stable. + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "hosts path was a directory", + )); + } + + let temp_path = Self::get_temp_path(hosts_path)?; + let begin_marker = format!("# DO NOT EDIT {} BEGIN", &self.tag); let end_marker = format!("# DO NOT EDIT {} END", &self.tag); @@ -218,14 +252,74 @@ impl HostsBuilder { writeln!(&mut s, "{}", line)?; } + match Self::write_and_swap(&temp_path, hosts_path, &s) { + Err(_) => { + Self::write_clobber(hosts_path, &s)?; + log::debug!("wrote hosts file with the clobber fallback strategy"); + }, + _ => { + log::debug!("wrote hosts file with the write-and-swap strategy"); + }, + } + Ok(()) + } + + fn write_and_swap(temp_path: &Path, hosts_path: &Path, contents: &[u8]) -> io::Result<()> { + // Copy the file we plan on modifying so its permissions and metadata are preserved. + std::fs::copy(&hosts_path, &temp_path)?; + Self::write_clobber(temp_path, contents)?; + std::fs::rename(temp_path, hosts_path)?; + Ok(()) + } + + fn write_clobber(hosts_path: &Path, contents: &[u8]) -> io::Result<()> { OpenOptions::new() .create(true) .read(true) .write(true) .truncate(true) .open(hosts_path)? - .write_all(&s)?; - + .write_all(contents)?; Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::path::Path; + + #[test] + fn test_temp_path_good() { + let hosts_path = Path::new("/etc/hosts"); + let temp_path = HostsBuilder::get_temp_path(hosts_path).unwrap(); + println!("{:?}", temp_path); + assert!(temp_path + .file_name() + .unwrap() + .to_str() + .unwrap() + .starts_with("hosts.tmp")); + } + + #[test] + fn test_temp_path_invalid() { + let hosts_path = Path::new("/"); + assert!(HostsBuilder::get_temp_path(hosts_path).is_err()); + } + + #[test] + fn test_write() { + let (mut temp_file, temp_path) = tempfile::NamedTempFile::new().unwrap().into_parts(); + temp_file.write_all(b"preexisting\ncontent").unwrap(); + let mut builder = HostsBuilder::new("foo"); + builder.add_hostname([1, 1, 1, 1].into(), "whatever"); + builder.write_to(&temp_path).unwrap(); + + let contents = std::fs::read_to_string(&temp_path).unwrap(); + println!("contents: {}", contents); + assert!(contents.starts_with("preexisting\ncontent")); + assert!(contents.contains("# DO NOT EDIT foo BEGIN")); + assert!(contents.contains("1.1.1.1 whatever")); + } +}