From b6ce16bc001894dc0f0e5815db03f7cafa76d87b Mon Sep 17 00:00:00 2001 From: Jake McGinty Date: Tue, 1 Feb 2022 13:53:31 +0900 Subject: [PATCH] server: add better validation to the associations endpoint (#194) --- server/src/db/association.rs | 124 +++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/server/src/db/association.rs b/server/src/db/association.rs index d6eb24f..cca9ebb 100644 --- a/server/src/db/association.rs +++ b/server/src/db/association.rs @@ -57,6 +57,30 @@ impl DatabaseAssociation { cidr_id_2, } = &contents; + // Verify an existing association doesn't currently exist + let existing_associations: usize = conn.query_row( + "SELECT COUNT(*) + FROM associations + WHERE (cidr_id_1 = ?1 AND cidr_id_2 = ?2) OR (cidr_id_1 = ?2 AND cidr_id_2 = ?1)", + params![cidr_id_1, cidr_id_2], + |r| r.get(0), + )?; + if existing_associations > 0 { + return Err(ServerError::InvalidQuery); + } + + // Verify both provided CIDR IDs exist + let existing_cidrs: usize = conn.query_row( + "SELECT COUNT(*) + FROM cidrs + WHERE id = ?1 OR id = ?2", + params![cidr_id_1, cidr_id_2], + |r| r.get(0), + )?; + if existing_cidrs != 2 { + return Err(ServerError::InvalidQuery); + } + conn.execute( "INSERT INTO associations (cidr_id_1, cidr_id_2) VALUES (?1, ?2)", @@ -89,3 +113,103 @@ impl DatabaseAssociation { Ok(auth_iter.collect::, rusqlite::Error>>()?) } } + +#[cfg(test)] +mod tests { + use crate::test; + use shared::{CidrContents, Error}; + + use super::*; + + #[tokio::test] + async fn test_double_add() -> Result<(), Error> { + let server = test::Server::new()?; + + let contents = AssociationContents { + cidr_id_1: 1, + cidr_id_2: 2, + }; + let contents_flipped = AssociationContents { + cidr_id_1: 2, + cidr_id_2: 1, + }; + let res = server + .form_request( + test::ADMIN_PEER_IP, + "POST", + "/v1/admin/associations", + &contents, + ) + .await; + assert!(res.status().is_success()); + + let res = server + .form_request( + test::ADMIN_PEER_IP, + "POST", + "/v1/admin/associations", + &contents, + ) + .await; + assert!(res.status().is_client_error()); + + let res = server + .form_request( + test::ADMIN_PEER_IP, + "POST", + "/v1/admin/associations", + &contents_flipped, + ) + .await; + assert!(res.status().is_client_error()); + Ok(()) + } + + #[tokio::test] + async fn test_nonexistent_cidr_id() -> Result<(), Error> { + let server = test::Server::new()?; + + // Verify both provided CIDR IDs exist + let last_cidr_id: i64 = + server + .db() + .lock() + .query_row("SELECT COUNT(*) FROM cidrs", params![], |r| r.get(0))?; + let contents = AssociationContents { + cidr_id_1: 1, + cidr_id_2: last_cidr_id + 1, + }; + let res = server + .form_request( + test::ADMIN_PEER_IP, + "POST", + "/v1/admin/associations", + &contents, + ) + .await; + assert!(!res.status().is_success()); + + let cidr = CidrContents { + name: "experimental".to_string(), + cidr: test::EXPERIMENTAL_CIDR.parse()?, + parent: Some(test::ROOT_CIDR_ID), + }; + + let res = server + .form_request(test::ADMIN_PEER_IP, "POST", "/v1/admin/cidrs", &cidr) + .await; + assert!(res.status().is_success()); + + let res = server + .form_request( + test::ADMIN_PEER_IP, + "POST", + "/v1/admin/associations", + &contents, + ) + .await; + assert!(res.status().is_success()); + + Ok(()) + } +}