diff --git a/Cargo.lock b/Cargo.lock index 6aed790..dede564 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,6 +56,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "hex-literal" version = "0.4.1" @@ -64,9 +75,45 @@ checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" [[package]] name = "libc" -version = "0.2.139" +version = "0.2.146" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" +checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] [[package]] name = "sha2" @@ -81,9 +128,10 @@ dependencies = [ [[package]] name = "stream_limiter" -version = "1.2.0" +version = "2.0.0" dependencies = [ "hex-literal", + "rand", "sha2", ] @@ -98,3 +146,9 @@ name = "version_check" version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" diff --git a/Cargo.toml b/Cargo.toml index 471036a..32eba26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,7 @@ keywords = ["rate", "rate_limiting", "synchronous", "stream"] [dev-dependencies] sha2 = "0.10.6" hex-literal = "0.4.1" +rand = { version = "0.8.5", features = ["small_rng"] } + +[features] +heavy_testing = [] diff --git a/src/lib.rs b/src/lib.rs index 2c0c4f6..43bdfe4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,18 +15,60 @@ //! assert_eq!(now.elapsed().as_secs(), 9); //! ``` use std::{ + debug_assert, io::{self, Read, Write}, time::Duration, }; +const ACCEPTABLE_SPEED_DIFF: f64 = 4.0 / 100.0; + +#[derive(Clone, Debug)] pub struct LimiterOptions { - window_length: u128, - window_time: Duration, - bucket_size: usize, + pub window_length: u64, + pub window_time: Duration, + pub bucket_size: u64, } impl LimiterOptions { - pub fn new(window_length: u128, window_time: Duration, bucket_size: usize) -> LimiterOptions { + pub fn new( + mut window_length: u64, + mut window_time: Duration, + mut bucket_size: u64, + ) -> LimiterOptions { + let rate = window_length.min(bucket_size) as f64; + let tw: f64 = window_time.as_nanos() as f64; + let init_speed = (rate / tw) * 1_000_000.0; + + let mut new_speed = init_speed; + let mut new_wlen = window_length; + let mut new_wtime = window_time; + let mut new_bsize = bucket_size; + + // While the difference between the intented speed (init_speed) and the reduced one (new_speed) is under the threshold + // Each iteration, divide all the options by 2, and recompute the speed (in order to check if it's not altered) + // Because we want the values BEFORE the speed is above the threshold, assign the new values on start of the new iter only + while ((new_speed / init_speed) - 1.0).abs() < ACCEPTABLE_SPEED_DIFF { + // Values from past iter, we know they're under the threshold + window_length = new_wlen; + window_time = new_wtime; + bucket_size = new_bsize; + + // If values aren't dividable by 2 + if (new_wlen == 1) || (new_bsize == 1) || (new_wtime.as_nanos() == 1) { + break; + } + + // Reduce the options + new_wlen /= 2; + new_wtime /= 2; + new_bsize /= 2; + + // Recompute the new speed + let rate = new_wlen.min(new_bsize) as f64; + let tw: f64 = new_wtime.as_nanos() as f64; + new_speed = (rate / tw) * 1_000_000.0; + } + LimiterOptions { window_length, window_time, @@ -46,6 +88,7 @@ where write_opt: Option, last_read_check: Option, last_write_check: Option, + additionnal_tokens: (u64, u64), } impl Limiter @@ -77,17 +120,18 @@ where }, read_opt, write_opt, + additionnal_tokens: (0, 0), } } - fn stream_cap_limit(&self) -> (Option, Option) { + fn stream_cap_limit(&self) -> (Option, Option) { let read_cap = if let Some(LimiterOptions { window_length, bucket_size, .. }) = self.read_opt { - Some(std::cmp::min(window_length as usize, bucket_size)) + Some(std::cmp::min(window_length, bucket_size)) } else { None }; @@ -97,25 +141,33 @@ where .. }) = self.write_opt { - Some(std::cmp::min(window_length as usize, bucket_size)) + Some(std::cmp::min(window_length, bucket_size)) } else { None }; (read_cap, write_cap) } - fn tokens_available(&self) -> (Option, Option) { + fn tokens_available(&self) -> (Option, Option) { let read_tokens = if let Some(LimiterOptions { window_length, window_time, bucket_size, }) = self.read_opt { - Some(std::cmp::min( - ((self.last_read_check.unwrap().elapsed().as_nanos() / window_time.as_nanos()) - * window_length) as usize, - bucket_size, - )) + let lrc = match u64::try_from(self.last_read_check.unwrap().elapsed().as_nanos()) { + Ok(n) => n, + // Will cap the last_read_check at a duration of about 584 years + Err(_) => u64::MAX, + }; + Some( + std::cmp::min( + (lrc / u64::try_from(window_time.as_nanos()) + .expect("Window time nanos > u64::MAX")) + * window_length, + bucket_size, + ) + self.additionnal_tokens.0, + ) } else { None }; @@ -125,11 +177,19 @@ where bucket_size, }) = self.write_opt { - Some(std::cmp::min( - ((self.last_write_check.unwrap().elapsed().as_nanos() / window_time.as_nanos()) - * window_length) as usize, - bucket_size, - )) + let lwc = match u64::try_from(self.last_write_check.unwrap().elapsed().as_nanos()) { + Ok(n) => n, + // Will cap the last_read_check at a duration of about 584 years + Err(_) => u64::MAX, + }; + Some( + std::cmp::min( + (lwc / u64::try_from(window_time.as_nanos()) + .expect("Window time nanos > u64::MAX")) + * window_length, + bucket_size, + ) + self.additionnal_tokens.1, + ) } else { None }; @@ -152,7 +212,7 @@ where /// If you didn't read for 10 secondes in this stream and you try to read 10 bytes, it will read instantly. fn read(&mut self, buf: &mut [u8]) -> io::Result { let mut read = 0; - let mut buf_left = buf.len(); + let mut buf_left = u64::try_from(buf.len()).expect("R buflen to u64"); let readlimit = if let (Some(limit), _) = self.stream_cap_limit() { limit } else { @@ -167,22 +227,31 @@ where } else { Duration::ZERO }; - self.last_read_check = Some(std::time::Instant::now()); std::thread::sleep(window_time.saturating_sub(elapsed)); + debug_assert!(self.tokens_available().0.unwrap() > 0); continue; } // Before reading so that we don't count the time it takes to read self.last_read_check = Some(std::time::Instant::now()); - let buf_read_end = read + nb_bytes_readable.min(buf_left); - let read_now = self.stream.read(&mut buf[read..buf_read_end])?; - if read_now < nb_bytes_readable { + let read_start = usize::try_from(read).expect("R read_start to usize"); + let read_end = usize::try_from(read + nb_bytes_readable.min(buf_left)) + .expect("R read_end to usize"); + let read_now = u64::try_from(self.stream.read(&mut buf[read_start..read_end])?) + .expect("R read_now to u64"); + if read_now == 0 { break; } + if read_now < nb_bytes_readable { + self.additionnal_tokens.0 = self + .additionnal_tokens + .0 + .saturating_add(nb_bytes_readable - read_now); + } read += read_now; buf_left -= read_now; } self.last_read_check = Some(std::time::Instant::now()); - Ok(read) + Ok(usize::try_from(read).expect("R return to usize")) } } @@ -194,7 +263,7 @@ where /// If you didn't write for 10 secondes in this stream and you try to write 10 bytes, it will write instantly. fn write(&mut self, buf: &[u8]) -> io::Result { let mut write = 0; - let mut buf_left = buf.len(); + let mut buf_left = u64::try_from(buf.len()).expect("W buflen to u64"); let writelimit = if let (_, Some(limit)) = self.stream_cap_limit() { limit } else { @@ -209,22 +278,31 @@ where } else { Duration::ZERO }; - self.last_write_check = Some(std::time::Instant::now()); std::thread::sleep(window_time.saturating_sub(elapsed)); + debug_assert!(self.tokens_available().1.unwrap() > 0); continue; } // Before reading so that we don't count the time it takes to read self.last_write_check = Some(std::time::Instant::now()); - let buf_write_end = write + nb_bytes_writable.min(buf_left); - let write_now = self.stream.write(&buf[write..buf_write_end])?; + let write_start = usize::try_from(write).expect("W write_start to usize"); + let write_end = usize::try_from(write + nb_bytes_writable.min(buf_left)) + .expect("W write_end to usize"); + let write_now = u64::try_from(self.stream.write(&buf[write_start..write_end])?) + .expect("W write_now_ to u64"); if write_now < nb_bytes_writable { break; } + if write_now < nb_bytes_writable { + self.additionnal_tokens.1 = self + .additionnal_tokens + .1 + .saturating_add(nb_bytes_writable - write_now); + } write += write_now; buf_left -= write_now; } self.last_write_check = Some(std::time::Instant::now()); - Ok(write) + Ok(usize::try_from(write).expect("W return to usize")) } fn flush(&mut self) -> io::Result<()> { diff --git a/tests/network.rs b/tests/network.rs new file mode 100644 index 0000000..40a4a09 --- /dev/null +++ b/tests/network.rs @@ -0,0 +1,143 @@ +mod utils; + +mod tests { + use std::io::{Read, Write}; + use std::net::{TcpListener, TcpStream}; + use std::time::Duration; + use stream_limiter::{Limiter, LimiterOptions}; + + #[test] + fn test_limit_read() { + let listener = TcpListener::bind("127.0.0.1:34254").unwrap(); + std::thread::spawn(|| { + println!("W] Connecting..."); + let mut stream = TcpStream::connect("127.0.0.1:34254").unwrap(); + println!("W] Writing ..."); + stream.write(&[42u8; 10]).unwrap(); + println!("W] OK"); + }); + println!("R] Listening ..."); + for stream in listener.incoming() { + println!("R] Stream {:?} connected", stream); + let mut limiter = Limiter::new( + stream.unwrap(), + Some(LimiterOptions::new(9, Duration::from_secs(1), 10)), + None, + ); + println!("R] Reading with limitation"); + let mut buffer = [0u8; 10]; + let now = std::time::Instant::now(); + limiter.read(&mut buffer).unwrap(); + println!("R] Result: {:?} (len {})", buffer, buffer.len()); + assert_eq!(buffer, [42; 10]); + assert_eq!(now.elapsed().as_secs(), 1); + break; + } + } + + #[test] + fn test_limit_write() { + let listener = TcpListener::bind("127.0.0.1:34255").unwrap(); + std::thread::spawn(|| { + println!("W] Connecting..."); + let stream = TcpStream::connect("127.0.0.1:34255").unwrap(); + let mut limiter = Limiter::new( + stream, + None, + Some(LimiterOptions::new(9, Duration::from_secs(1), 10)), + ); + println!("W] Writing ..."); + limiter.write(&[42u8; 10]).unwrap(); + println!("W] OK"); + }); + println!("R] Listening ..."); + for stream in listener.incoming() { + let mut stream = stream.unwrap(); + println!("R] Stream {:?} connected", stream); + println!("R] Reading with limitation"); + let mut buffer = [0; 10]; + let now = std::time::Instant::now(); + stream.read_exact(&mut buffer).unwrap(); + + println!("R] Result: {:?} (len {})", buffer, buffer.len()); + assert_eq!(buffer, [42; 10]); + assert_eq!(now.elapsed().as_secs(), 1); + break; + } + } + + #[test] + fn test_limit_both() { + let listener = TcpListener::bind("127.0.0.1:34256").unwrap(); + std::thread::spawn(|| { + println!("W] Connecting..."); + let stream = TcpStream::connect("127.0.0.1:34256").unwrap(); + let mut limiter = Limiter::new( + stream, + None, + Some(LimiterOptions::new(9, Duration::from_secs(1), 10)), + ); + println!("W] Writing ..."); + limiter.write_all(&[42u8; 10]).unwrap(); + println!("W] OK"); + }); + println!("R] Listening ..."); + for stream in listener.incoming() { + println!("R] Stream {:?} connected", stream); + println!("R] Reading with limitation"); + let mut buffer = [0; 10]; + let now = std::time::Instant::now(); + let mut limiter = Limiter::new( + stream.unwrap(), + Some(LimiterOptions::new(9, Duration::from_secs(1), 10)), + None, + ); + limiter.read_exact(&mut buffer).unwrap(); + + println!("R] Result: {:?} (len {})", buffer, buffer.len()); + assert_eq!(buffer, [42; 10]); + assert_eq!(now.elapsed().as_secs(), 1); + break; + } + } + + #[test] + fn test_peak_both() { + let listener = TcpListener::bind("127.0.0.1:34257").unwrap(); + std::thread::spawn(|| { + println!("W] Connecting..."); + let stream = TcpStream::connect("127.0.0.1:34257").unwrap(); + let mut limiter = Limiter::new( + stream, + None, + Some(LimiterOptions::new(10, Duration::from_secs(1), 10)), + ); + println!("W] Writing ..."); + limiter.write_all(&[42u8; 10]).unwrap(); + println!("W] OK"); + }); + println!("R] Listening ..."); + for stream in listener.incoming() { + println!("R] Stream {:?} connected", stream); + println!("R] Reading with limitation"); + let mut buffer = [0; 10]; + let now = std::time::Instant::now(); + let mut limiter = Limiter::new( + stream.unwrap(), + Some(LimiterOptions::new(10, Duration::from_secs(1), 10)), + None, + ); + limiter.read_exact(&mut buffer).unwrap(); + + println!( + "R] Result: {:?} (len {}), in {:?}", + buffer, + buffer.len(), + now.elapsed() + ); + assert_eq!(now.elapsed().as_secs(), 0); + assert_eq!(buffer, [42; 10]); + break; + } + } +} diff --git a/tests/parametric.rs b/tests/parametric.rs new file mode 100644 index 0000000..ac825c1 --- /dev/null +++ b/tests/parametric.rs @@ -0,0 +1,208 @@ +mod utils; + +mod tests { + use crate::utils::paramtests::start_parametric_test; + use sha2::Digest; + use std::{ + io::{Read, Write}, + time::Duration, + }; + use stream_limiter::{Limiter, LimiterOptions}; + + const INSTANT_IO_EPS: u128 = 1_500_000; + const DATALEN_DIVIDER: usize = 5; + + fn get_data_hash(data: &Vec) -> [u8; 32] { + let mut hasher = sha2::Sha256::new(); + hasher.update(data); + hasher.finalize().into() + } + + fn get_random_options(rng: &mut R, datalen: usize) -> Option { + if rng.gen_bool(0.08) { + None + } else { + Some(LimiterOptions::new( + rng.gen_range((datalen / DATALEN_DIVIDER)..(datalen * DATALEN_DIVIDER)) as u64, + Duration::from_millis( + rng.gen_range(DATALEN_DIVIDER..(1000 / DATALEN_DIVIDER)) as u64 + ), + rng.gen_range((datalen / DATALEN_DIVIDER)..(datalen * DATALEN_DIVIDER)) as u64, + )) + } + } + + #[test] + fn test_buffer() { + fn paramtest_buffer(mut rng: R) { + let datalen = rng.gen_range(10..1024 * 512); + + let outbuf = std::io::Cursor::new(vec![]); + let wopts: Option = get_random_options(&mut rng, datalen); + let ropts = get_random_options(&mut rng, datalen); + + let data: Vec = (0..datalen).map(|_| rng.gen::()).collect(); + let buf = data.clone(); + let data_checksum = get_data_hash(&buf); + + let mut limiter = Limiter::new(outbuf, ropts.clone(), wopts.clone()); + let now = std::time::Instant::now(); + let nwrite = limiter.write(&buf).unwrap(); + let elapsed = now.elapsed(); + assert_eq!(nwrite, datalen); + if let Some(ref opts) = wopts { + let rate = opts.window_length.min(opts.bucket_size); + if (datalen as u64) > rate { + assert!(elapsed.as_nanos() > opts.window_time.as_nanos()); + } else { + assert!(elapsed.as_nanos() <= opts.window_time.as_nanos()); + } + } else { + assert!(elapsed.as_nanos() < INSTANT_IO_EPS); + } + + assert_eq!(get_data_hash(limiter.stream.get_ref()), data_checksum); + + let read_buf = limiter.stream.into_inner(); + let mut buf = vec![0; datalen]; + let mut limiter = Limiter::new(std::io::Cursor::new(read_buf), ropts.clone(), wopts); + let now = std::time::Instant::now(); + let nread = limiter.read(buf.as_mut_slice()).unwrap(); + let elapsed = now.elapsed(); + assert_eq!(nread, datalen); + if let Some(ref opts) = ropts { + let rate = opts.window_length.min(opts.bucket_size); + if datalen as u64 > rate { + assert!(elapsed.as_nanos() > opts.window_time.as_nanos()); + } else { + assert!(elapsed.as_nanos() <= opts.window_time.as_nanos()); + } + } else { + assert!(elapsed.as_nanos() < INSTANT_IO_EPS); + } + + assert_eq!(get_data_hash(&buf), data_checksum); + assert_eq!(&data, &buf); + } + start_parametric_test( + 100, + vec![14398057406427516238, 13640999559978117227], + paramtest_buffer, + ); + } + + #[test] + fn test_tcp() { + use std::net::{TcpListener, TcpStream}; + + fn paramtest_tcp(mut rng: R) { + let datalen = rng.gen_range(10..1024 * 512); + let wopts_connector: Option = get_random_options(&mut rng, datalen); + let wopts_listener: Option = get_random_options(&mut rng, datalen); + let ropts_connector = get_random_options(&mut rng, datalen); + let ropts_listener = get_random_options(&mut rng, datalen); + let data: Vec = (0..datalen).map(|_| rng.gen::()).collect(); + let data_c = data.clone(); + let datahash = get_data_hash(&data); + let mut port = 10000 + rng.gen_range(0..(u16::MAX - 10000)); + + let listener = loop { + match TcpListener::bind(format!("127.0.0.1:{port}")) { + Ok(l) => break l, + Err(_) => { + port += 1; + } + } + }; + + let connector = std::thread::spawn(move || { + let stream = TcpStream::connect(format!("127.0.0.1:{port}")).unwrap(); + let mut limiter = + Limiter::new(stream, ropts_connector.clone(), wopts_connector.clone()); + + let now = std::time::Instant::now(); + limiter.write_all(&data_c).unwrap(); + let elapsed = now.elapsed(); + if let Some(ref opts) = wopts_connector { + let rate = opts.window_length.min(opts.bucket_size); + if datalen as u64 > rate { + assert!(elapsed.as_nanos() > opts.window_time.as_nanos()); + } else { + assert!(elapsed.as_nanos() <= opts.window_time.as_nanos()); + } + } + + let mut limiter = Limiter::new( + limiter.stream, + ropts_connector.clone(), + wopts_connector.clone(), + ); + let mut buf = vec![0; datalen]; + assert_ne!(get_data_hash(&buf), datahash); + let now = std::time::Instant::now(); + limiter.read_exact(&mut buf).unwrap(); + let elapsed = now.elapsed(); + assert_eq!(get_data_hash(&buf), datahash); + if let Some(ref opts) = ropts_connector { + let rate = opts.window_length.min(opts.bucket_size); + if datalen as u64 > rate { + assert!(elapsed.as_nanos() > opts.window_time.as_nanos()); + } else { + assert!(elapsed.as_nanos() <= opts.window_time.as_nanos()); + } + } + }); + std::thread::sleep(Duration::from_millis(50)); + + for stream in listener.incoming() { + let mut limiter = Limiter::new( + stream.unwrap(), + ropts_listener.clone(), + wopts_listener.clone(), + ); + + let mut buf = vec![0; datalen]; + let now = std::time::Instant::now(); + limiter.read_exact(&mut buf).unwrap(); + let elapsed = now.elapsed(); + assert_eq!(get_data_hash(&buf), datahash); + if let Some(ref opts) = ropts_listener { + let rate = opts.window_length.min(opts.bucket_size); + if datalen as u64 > rate { + assert!(elapsed.as_nanos() > opts.window_time.as_nanos()); + } else { + assert!(elapsed.as_nanos() <= opts.window_time.as_nanos()); + } + } + + let mut limiter = + Limiter::new(limiter.stream, ropts_listener, wopts_listener.clone()); + let now = std::time::Instant::now(); + limiter.write_all(&data).unwrap(); + let elapsed = now.elapsed(); + if let Some(opts) = wopts_listener { + let rate = opts.window_length.min(opts.bucket_size); + if datalen as u64 > rate { + assert!(elapsed.as_nanos() > opts.window_time.as_nanos()); + } else { + assert!(elapsed.as_nanos() <= opts.window_time.as_nanos()); + } + } + break; + } + assert!(connector.join().is_ok()); + } + start_parametric_test( + 100, + vec![ + 15164449282496041257, + 3911014536179701959, + 2770066496784563521, + 16118644738678043134, + 15039019555209573434, + 18348045085902583881, + ], + paramtest_tcp, + ); + } +} diff --git a/tests/read.rs b/tests/read.rs index 73fe780..71f8950 100644 --- a/tests/read.rs +++ b/tests/read.rs @@ -72,7 +72,7 @@ mod tests { let file = File::open("tests/resources/test.txt").unwrap(); let mut limiter = Limiter::new( file, - Some(LimiterOptions::new(100, Duration::from_secs(1), 1000)), + Some(LimiterOptions::new(1000, Duration::from_secs(1), 1000)), None, ); assert!(limiter.limits().0); diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index cebdb6b..7ca840c 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,6 +1,8 @@ use hex_literal::hex; use sha2::Digest; +pub mod paramtests; + // The checksum and the size of the data (to trim the buffer) pub const FILE_BIG: ([u8; 32], usize) = ( hex!("55e28ecbd9ea1df018ffacd137ee8d62551eb2d6fbd46508bca7809005ff267a"), @@ -27,10 +29,10 @@ pub fn assert_checksum_samedata(buf: &[u8], data: u8) { if N <= 50 { println!("{:?}\n{:?}", samedata, buf); } - hasher.update(&[data; N]); + hasher.update([data; N]); let samedata_hash = hasher.finalize(); let mut hasher = sha2::Sha256::new(); - hasher.update(&buf); + hasher.update(buf); assert_eq!(hasher.finalize()[..], samedata_hash[..]); } diff --git a/tests/utils/paramtests.rs b/tests/utils/paramtests.rs new file mode 100644 index 0000000..35312f9 --- /dev/null +++ b/tests/utils/paramtests.rs @@ -0,0 +1,29 @@ +use std::time::Duration; + +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; + +#[allow(dead_code)] +pub fn start_parametric_test(nbiter: usize, regressions: Vec, function: F) +where + F: Fn(SmallRng), +{ + #[cfg(feature = "heavy_testing")] + let nbiter = nbiter * 100; + + for seed in regressions.iter() { + println!("Test regression seed {}", seed); + function(SmallRng::seed_from_u64(*seed)); + std::thread::sleep(Duration::from_millis(50)); + } + let mut seeder = SmallRng::from_entropy(); + + let nspace = nbiter.to_string().len(); + for n in 0..nbiter { + let new_seed: u64 = seeder.gen(); + print!("{:1$}", n + 1, nspace); + println!("/{}| Seed {:20}", nbiter, new_seed); + function(SmallRng::seed_from_u64(new_seed)); + std::thread::sleep(Duration::from_millis(50)); + } +}