diff --git a/Cargo.toml b/Cargo.toml index 63b3950..36b5f9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,3 +29,4 @@ url = "2.5.2" [dev-dependencies] async-std = { version = "1.13.0", features = ["attributes"], default-features = false } rstest = "0.23.0" +tempfile = "3.14.0" diff --git a/src/numtracker.rs b/src/numtracker.rs index a0f1de8..d58dec1 100644 --- a/src/numtracker.rs +++ b/src/numtracker.rs @@ -30,10 +30,10 @@ pub struct NumTracker { impl NumTracker { /// Build a numtracker than will provide locked access to subdirectories that exists and no-op /// trackers for beamlines that do not have subdirectories. - pub fn for_root_directory(root: Option) -> Result { + pub fn for_root_directory>(root: Option

) -> Result { let mut bl_locks: HashMap> = Default::default(); if let Some(dir) = root { - for entry in dir.read_dir()? { + for entry in dir.as_ref().read_dir()? { let dir = entry?; if dir.file_type()?.is_dir() { if let Ok(name) = dir.file_name().into_string() { @@ -165,3 +165,171 @@ impl Display for InvalidExtension { } impl std::error::Error for InvalidExtension {} + +#[cfg(test)] +mod tests { + use std::fs; + use std::ops::Deref; + use std::time::Duration; + + use rstest::{fixture, rstest}; + use tempfile::{tempdir, TempDir}; + use tokio::time::timeout; + + use super::{InvalidExtension, NumTracker}; + + /// Wrapper around a NumTracker to ensure the tempdir is not dropped while it is still required + struct TempTracker(NumTracker, TempDir); + impl Deref for TempTracker { + type Target = NumTracker; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + #[rstest::fixture] + fn root() -> TempDir { + let root = tempdir().unwrap(); + + fs::create_dir(root.as_ref().join("i22")).unwrap(); + fs::File::create(root.as_ref().join("i22").join("122.i22")).unwrap(); + + fs::create_dir(root.as_ref().join("b21")).unwrap(); + + root + } + + #[fixture] + fn nt(root: TempDir) -> TempTracker { + TempTracker(NumTracker::for_root_directory(Some(&root)).unwrap(), root) + } + + #[rstest] + #[tokio::test[]] + async fn exclusive_locking(nt: TempTracker) { + let i22 = nt.for_beamline("i22", None).await; + + // difficult to test but this should be locked until i22 is dropped + nt.bl_locks.get("i22").unwrap().try_lock().unwrap_err(); + nt.bl_locks.get("i22").unwrap().try_lock().unwrap_err(); + nt.bl_locks.get("i22").unwrap().try_lock().unwrap_err(); + + drop(i22); + // lock should now be free + _ = nt.bl_locks.get("i22").unwrap().try_lock().unwrap(); + } + + #[rstest] + #[tokio::test] + async fn multiple_beamlines_not_exclusive(nt: TempTracker) { + // trackers for different beamlines can be held concurrently + let _i22 = nt.for_beamline("i22", None).await.unwrap(); + let _b21 = nt.for_beamline("b21", None).await.unwrap(); + } + + #[rstest] + #[tokio::test] + async fn unmanaged_beamlines_not_locked(nt: TempTracker) { + let i11 = nt.for_beamline("i11", None); + let i11_2 = nt.for_beamline("i11", None); + let i11_3 = nt.for_beamline("i11", None); + let i11_4 = nt.for_beamline("i11", None); + + // This should never get near 1s but in case something deadlocks we want to exit early. The + // test will still fail successfully in this case. + timeout(Duration::from_secs(1), async { + i11.await.unwrap(); + i11_2.await.unwrap(); + i11_3.await.unwrap(); + i11_4.await.unwrap(); + }) + .await + .expect("Timed out waiting for unmanaged trackers"); + } + + #[rstest] + #[tokio::test] + async fn unmanaged_beamline_has_no_numbers(nt: TempTracker) { + let i11 = nt.for_beamline("i11", None).await.unwrap(); + if let Some(num) = i11.prev().await.unwrap() { + panic!("Unmanaged beamline returned previous number: {num}"); + } + // setting an unmanaged beamline is a no-op + i11.set(111).await.unwrap(); + if let Some(num) = i11.prev().await.unwrap() { + panic!("Unmanaged beamline returned previous number: {num}"); + } + } + + #[rstest] + #[tokio::test] + async fn bump_numbers(nt: TempTracker) { + let i22 = nt.for_beamline("i22", None).await.unwrap(); + assert_eq!(i22.prev().await.unwrap(), Some(122)); + i22.set(123).await.unwrap(); + assert_eq!(i22.prev().await.unwrap(), Some(123)); + assert!( + !fs::exists(nt.1.as_ref().join("i22").join("122.i22")).unwrap(), + "previous number file not deleted" + ); + } + + #[rstest] + #[tokio::test] + async fn non_consecutive_files_left(nt: TempTracker) { + let i22 = nt.for_beamline("i22", None).await.unwrap(); + assert_eq!(i22.prev().await.unwrap(), Some(122)); + i22.set(244).await.unwrap(); + assert_eq!(i22.prev().await.unwrap(), Some(244)); + assert!( + fs::exists(nt.1.as_ref().join("i22").join("122.i22")).unwrap(), + "Non-consecutive previous file was removed" + ); + } + + #[rstest] + #[tokio::test] + async fn alternative_extensions(nt: TempTracker) { + let i22 = nt.for_beamline("i22", None).await.unwrap(); // default i22 extension + assert_eq!(i22.prev().await.unwrap(), Some(122)); + drop(i22); + let i22 = nt.for_beamline("i22", Some("alt")).await.unwrap(); + assert_eq!(i22.prev().await.unwrap(), Some(0)); + i22.set(1234).await.unwrap(); + assert!( + fs::exists(nt.1.as_ref().join("i22").join("122.i22")).unwrap(), + "Existing extension file was removed" + ); + assert!( + fs::exists(nt.1.as_ref().join("i22").join("1234.alt")).unwrap(), + "New alternative extension file was not created" + ); + drop(i22); + } + + #[rstest] + #[tokio::test] + async fn invalid_extensions(nt: TempTracker) { + let Err(InvalidExtension) = nt.for_beamline("i22", Some("ext space")).await else { + panic!("Invalid extension was accepted"); + }; + + let Err(InvalidExtension) = nt.for_beamline("i22", Some("in:valid@chars")).await else { + panic!("Invalid extension was accepted"); + }; + + let Err(InvalidExtension) = nt.for_beamline("i22", Some("i22/../beamline")).await else { + panic!("Invalid extension was accepted"); + }; + assert_eq!(InvalidExtension.to_string(), "Extension is not valid"); + } + + #[rstest] + #[tokio::test] + async fn non_number_files(nt: TempTracker) { + fs::File::create(nt.1.as_ref().join("i22").join("string.i22")).unwrap(); + let i22 = nt.for_beamline("i22", None).await.unwrap(); + assert_eq!(i22.prev().await.unwrap(), Some(122)); + } +}