From 53264cc6c24912fd4220974f5292f7d0fa41d826 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Ko=C5=82aczkowski?= Date: Mon, 22 Jul 2024 14:23:55 +0200 Subject: [PATCH] Use parse_duration for parsing request timeouts and retry delays --- src/config.rs | 65 ++++++++++++++++++-------------------------------- src/context.rs | 32 +++++++++++-------------- src/report.rs | 14 +++++------ 3 files changed, 44 insertions(+), 67 deletions(-) diff --git a/src/config.rs b/src/config.rs index aabbf92..e786ba1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,14 +1,14 @@ -use std::collections::HashMap; -use std::error::Error; -use std::num::NonZeroUsize; -use std::path::PathBuf; -use std::str::FromStr; - use anyhow::anyhow; use chrono::Utc; use clap::builder::PossibleValue; use clap::{Parser, ValueEnum}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::error::Error; +use std::num::NonZeroUsize; +use std::path::PathBuf; +use std::str::FromStr; +use std::time::Duration; /// Parse a single key-value pair fn parse_key_val(s: &str) -> Result<(T, U), anyhow::Error> @@ -91,51 +91,32 @@ impl FromStr for Interval { /// Controls the min and max retry interval for retry mechanism #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] -pub struct RetryInterval { - pub min_ms: u64, - pub max_ms: u64, +pub struct RetryDelay { + pub min: Duration, + pub max: Duration, } -impl RetryInterval { +impl RetryDelay { pub fn new(time: &str) -> Option { let values: Vec<&str> = time.split(',').collect(); if values.len() > 2 { return None; } - let min_ms = RetryInterval::parse_time(values.first().unwrap_or(&""))?; - let max_ms = RetryInterval::parse_time(values.get(1).unwrap_or(&"")).unwrap_or(min_ms); - if min_ms > max_ms { + let min = parse_duration::parse(values.first().unwrap_or(&"")).ok()?; + let max = parse_duration::parse(values.get(1).unwrap_or(&"")).unwrap_or(min); + if min > max { None } else { - Some(RetryInterval { min_ms, max_ms }) - } - } - - fn parse_time(time: &str) -> Option { - let trimmed_time = time.trim(); - if trimmed_time.is_empty() { - return None; + Some(RetryDelay { min, max }) } - - let value_str = match trimmed_time { - s if s.ends_with("ms") => s.trim_end_matches("ms"), - s if s.ends_with('s') => { - let num = s.trim_end_matches('s').parse::().ok()?; - return Some(num * 1000); - } - _ => trimmed_time, - }; - - let value = value_str.trim().parse::().ok()?; - Some(value) } } -impl FromStr for RetryInterval { +impl FromStr for RetryDelay { type Err = String; fn from_str(s: &str) -> Result { - if let Some(interval) = RetryInterval::new(s) { + if let Some(interval) = RetryDelay::new(s) { Ok(interval) } else { Err(concat!( @@ -194,18 +175,18 @@ pub struct ConnectionConf { #[clap(long("consistency"), required = false, default_value = "LOCAL_QUORUM")] pub consistency: Consistency, - #[clap(long("request-timeout"), default_value = "5", value_name = "COUNT")] - pub request_timeout: NonZeroUsize, + #[clap(long("request-timeout"), default_value = "5s", value_name = "DURATION", value_parser = parse_duration::parse)] + pub request_timeout: Duration, - #[clap(long("retry-number"), default_value = "10", value_name = "COUNT")] - pub retry_number: u64, + #[clap(long("retries"), default_value = "3", value_name = "COUNT")] + pub retries: u64, #[clap( - long("retry-interval"), + long("retry-delay"), default_value = "100ms,5s", - value_name = "TIME[,TIME]" + value_name = "MIN[,MAX]" )] - pub retry_interval: RetryInterval, + pub retry_interval: RetryDelay, } #[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)] diff --git a/src/context.rs b/src/context.rs index bf93480..bb1e361 100644 --- a/src/context.rs +++ b/src/context.rs @@ -39,7 +39,7 @@ use tokio::time::{Duration, Instant}; use try_lock::TryLock; use uuid::{Variant, Version}; -use crate::config::{ConnectionConf, RetryInterval}; +use crate::config::{ConnectionConf, RetryDelay}; use crate::LatteError; fn ssl_context(conf: &&ConnectionConf) -> Result, CassError> { @@ -71,7 +71,7 @@ pub async fn connect(conf: &ConnectionConf) -> Result { let profile = ExecutionProfile::builder() .consistency(conf.consistency.scylla_consistency()) .load_balancing_policy(policy_builder.build()) - .request_timeout(Some(Duration::from_secs(conf.request_timeout.get() as u64))) + .request_timeout(Some(conf.request_timeout)) .build(); let scylla_session = SessionBuilder::new() @@ -85,7 +85,7 @@ pub async fn connect(conf: &ConnectionConf) -> Result { .map_err(|e| CassError(CassErrorKind::FailedToConnect(conf.addresses.clone(), e)))?; Ok(Context::new( scylla_session, - conf.retry_number, + conf.retries, conf.retry_interval, )) } @@ -369,11 +369,11 @@ impl Default for SessionStats { } pub fn get_exponential_retry_interval( - min_interval: u64, - max_interval: u64, + min_interval: Duration, + max_interval: Duration, current_attempt_num: u64, -) -> u64 { - let min_interval_float: f64 = min_interval as f64; +) -> Duration { + let min_interval_float: f64 = min_interval.as_secs_f64(); let mut current_interval: f64 = min_interval_float * (2u64.pow(current_attempt_num.try_into().unwrap_or(0)) as f64); @@ -381,7 +381,7 @@ pub fn get_exponential_retry_interval( current_interval += random::() * min_interval_float; current_interval -= min_interval_float / 2.0; - std::cmp::min(current_interval as u64, max_interval) + Duration::from_secs_f64(current_interval.min(max_interval.as_secs_f64())) } /// This is the main object that a workload script uses to interface with the outside world. @@ -392,7 +392,7 @@ pub struct Context { statements: HashMap>, stats: TryLock, retry_number: u64, - retry_interval: RetryInterval, + retry_interval: RetryDelay, #[rune(get, set, add_assign, copy)] pub load_cycle_count: u64, #[rune(get)] @@ -409,11 +409,7 @@ unsafe impl Send for Context {} unsafe impl Sync for Context {} impl Context { - pub fn new( - session: scylla::Session, - retry_number: u64, - retry_interval: RetryInterval, - ) -> Context { + pub fn new(session: scylla::Session, retry_number: u64, retry_interval: RetryDelay) -> Context { Context { session: Arc::new(session), statements: HashMap::new(), @@ -506,14 +502,14 @@ impl Context { let mut rs: Result = Err(QueryError::TimeoutError); let mut attempts = 0; - while attempts <= self.retry_number + 1 && Self::should_retry(&rs) { + while attempts <= self.retry_number && Self::should_retry(&rs) { if attempts > 0 { let current_retry_interval = get_exponential_retry_interval( - self.retry_interval.min_ms, - self.retry_interval.max_ms, + self.retry_interval.min, + self.retry_interval.max, attempts, ); - tokio::time::sleep(Duration::from_millis(current_retry_interval)).await; + tokio::time::sleep(current_retry_interval).await; } rs = f().await; attempts += 1; diff --git a/src/report.rs b/src/report.rs index 42c64b8..3c226dc 100644 --- a/src/report.rs +++ b/src/report.rs @@ -564,17 +564,17 @@ impl<'a> Display for RunConfigCmp<'a> { self.line("└─", "op", |conf| { Quantity::from(conf.sampling_interval.count()) }), - self.line("Request timeout", "", |conf| { - Quantity::from(conf.connection.request_timeout) + self.line("Request timeout", "s", |conf| { + Quantity::from(conf.connection.request_timeout.as_secs_f64()) }), self.line("Retries", "", |conf| { - Quantity::from(conf.connection.retry_number) + Quantity::from(conf.connection.retries) }), - self.line("├─ min interval", "ms", |conf| { - Quantity::from(conf.connection.retry_interval.min_ms) + self.line("├─ min delay", "ms", |conf| { + Quantity::from(conf.connection.retry_interval.min.as_secs_f64() * 1000.0) }), - self.line("└─ max interval", "ms", |conf| { - Quantity::from(conf.connection.retry_interval.max_ms) + self.line("└─ max delay", "ms", |conf| { + Quantity::from(conf.connection.retry_interval.max.as_secs_f64() * 1000.0) }), ];