diff --git a/pkcs11/src/api/mod.rs b/pkcs11/src/api/mod.rs
index 7ad2858..a9a8029 100644
--- a/pkcs11/src/api/mod.rs
+++ b/pkcs11/src/api/mod.rs
@@ -16,6 +16,7 @@ pub mod verify;
 use std::sync::atomic::Ordering;
 use std::{ptr::addr_of_mut, sync::Arc};
 
+use crate::config::device::{RetryThreadMessage, RETRY_THREAD};
 use crate::{
     backend::events::{fetch_slots_state, EventsManager},
     data::{self, DEVICE, EVENTS_MANAGER, THREADS_ALLOWED, TOKENS_STATE},
@@ -110,6 +111,9 @@ pub extern "C" fn C_Finalize(pReserved: CK_VOID_PTR) -> CK_RV {
         return cryptoki_sys::CKR_ARGUMENTS_BAD;
     }
     DEVICE.store(None);
+    if THREADS_ALLOWED.load(Ordering::Relaxed) {
+        RETRY_THREAD.send(RetryThreadMessage::Finalize).unwrap();
+    }
     EVENTS_MANAGER.write().unwrap().finalized = true;
 
     cryptoki_sys::CKR_OK
diff --git a/pkcs11/src/backend/login.rs b/pkcs11/src/backend/login.rs
index 72fad12..80c0df7 100644
--- a/pkcs11/src/backend/login.rs
+++ b/pkcs11/src/backend/login.rs
@@ -15,13 +15,36 @@ use std::{
     time::Duration,
 };
 
-use crate::config::{
-    config_file::{RetryConfig, UserConfig},
-    device::{InstanceAttempt, InstanceData, Slot},
+use crate::{
+    config::{
+        config_file::{RetryConfig, UserConfig},
+        device::{InstanceAttempt, InstanceData, Slot},
+    },
+    data::THREADS_ALLOWED,
 };
 
 use super::{ApiError, Error};
 
+#[derive(Debug)]
+enum ShouldHealthCheck {
+    /// The instance is ready to be used
+    RunDirectly,
+    /// The instance needs to first be health checked
+    HealthCheckFirst,
+}
+
+impl ShouldHealthCheck {
+    fn should_check(&self) -> bool {
+        matches!(self, ShouldHealthCheck::HealthCheckFirst)
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+enum HealthCheck {
+    Possible,
+    Avoid,
+}
+
 #[derive(Debug)]
 pub struct LoginCtx {
     slot: Arc<Slot>,
@@ -71,6 +94,35 @@ impl std::fmt::Display for LoginError {
     }
 }
 
+/// Perform a health check with a timeout of 1 second
+fn health_check_get_timeout(instance: &InstanceData) -> bool {
+    instance.config.client.clear_pool();
+    let config = &instance.config;
+    let uri_str = format!("{}/health/ready", config.base_path);
+    let mut req = config.client.get(&uri_str).timeout(Duration::from_secs(1));
+    if let Some(ref user_agent) = config.user_agent {
+        req = req.set("user-agent", user_agent);
+    }
+
+    match req.call() {
+        Ok(r) => {
+            if r.status() == 200 {
+                instance.clear_failed();
+                return true;
+            }
+            log::warn!("Failed retry {}", r.status_text());
+            instance.bump_failed();
+            false
+        }
+
+        Err(err) => {
+            log::warn!("Failed retry {err:?}");
+            instance.bump_failed();
+            false
+        }
+    }
+}
+
 impl LoginCtx {
     pub fn new(slot: Arc<Slot>, admin_allowed: bool, operator_allowed: bool) -> Self {
         let mut ck_state = CKS_RO_PUBLIC_SESSION;
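The one-second budget in the probe added above matters: a login call should not stall for a full connect timeout on a dead instance. A standalone sketch of the same pattern, using the `ureq` 2.x API that the diff itself relies on (`probe_ready` and its arguments are illustrative names, not part of the change):

```rust
use std::time::Duration;

/// Returns true when `{base_path}/health/ready` answers 200 within one second.
/// `base_path` is assumed to look like "https://nethsm.local/api/v1".
fn probe_ready(agent: &ureq::Agent, base_path: &str) -> bool {
    let url = format!("{base_path}/health/ready");
    match agent.get(&url).timeout(Duration::from_secs(1)).call() {
        Ok(resp) => resp.status() == 200,
        // In ureq 2.x both transport errors and 4xx/5xx statuses arrive as
        // Err; either way the instance is treated as "not ready".
        Err(_) => false,
    }
}
```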
@@ -113,7 +165,7 @@ impl LoginCtx {
     pub fn login(&mut self, user_type: CK_USER_TYPE, pin: String) -> Result<(), LoginError> {
         trace!("Login as {:?} with pin", user_type);
 
-        let expected = match user_type {
+        let (user_status, user_mode) = match user_type {
             CKU_CONTEXT_SPECIFIC => return Err(LoginError::InvalidUser),
             CKU_SO => {
                 trace!("administrator: {:?}", self.slot.administrator);
@@ -126,7 +178,7 @@ impl LoginCtx {
                     }),
                 };
                 self.admin_allowed = true;
-                (UserStatus::Administrator, self.administrator())
+                (UserStatus::Administrator, UserMode::Administrator)
             }
             CKU_USER => {
                 self.operator_login_override = match self.operator_config() {
@@ -137,46 +189,66 @@ impl LoginCtx {
                     }),
                 };
                 self.operator_allowed = true;
-                (UserStatus::Operator, self.operator())
+                (UserStatus::Operator, UserMode::Operator)
             }
             _ => return Err(LoginError::BadArgument),
         };
 
-        trace!("Config: {:?}", expected.1);
+        let got_user = self
+            .try_(get_current_user_status, user_mode)
+            .map_err(|err| {
+                error!("Login check failed: {err:?}");
+                LoginError::UserNotPresent
+            })?;
 
-        let config = expected.1.ok_or(LoginError::UserNotPresent)?.config;
-
-        if get_current_user_status(&config) == expected.0 {
-            self.ck_state = match expected.0 {
+        if got_user == user_status {
+            self.ck_state = match user_status {
                 UserStatus::Operator => CKS_RW_USER_FUNCTIONS,
                 UserStatus::Administrator => CKS_RW_SO_FUNCTIONS,
                 UserStatus::LoggedOut => CKS_RO_PUBLIC_SESSION,
             };
             Ok(())
         } else {
-            error!("Failed to login as {:?} with pin", expected.0);
+            error!("Failed to login as {user_mode:?} with pin, got user {got_user:?}");
            Err(LoginError::IncorrectPin)
        }
    }
 
-    fn next_instance(&self) -> &InstanceData {
+    fn next_instance(
+        &self,
+        accept_health_check: HealthCheck,
+    ) -> (&InstanceData, ShouldHealthCheck) {
+        let threads_allowed = THREADS_ALLOWED.load(Relaxed);
         let index = self.slot.instance_balancer.fetch_add(1, Relaxed);
         let index = index % self.slot.instances.len();
         let instance = &self.slot.instances[index];
-        match instance.should_try() {
-            InstanceAttempt::Failed => {}
-            InstanceAttempt::Working | InstanceAttempt::Retry => return instance,
+        match (instance.should_try(), threads_allowed, accept_health_check) {
+            (InstanceAttempt::Failed, _, _)
+            | (InstanceAttempt::Retry, true, _)
+            | (InstanceAttempt::Retry, false, HealthCheck::Avoid) => {}
+            (InstanceAttempt::Working, _, _) => return (instance, ShouldHealthCheck::RunDirectly),
+            (InstanceAttempt::Retry, false, HealthCheck::Possible) => {
+                return (instance, ShouldHealthCheck::HealthCheckFirst)
+            }
         }
         for i in 0..self.slot.instances.len() - 1 {
             let instance = &self.slot.instances[index + i];
-            match instance.should_try() {
-                InstanceAttempt::Failed => continue,
-                InstanceAttempt::Working | InstanceAttempt::Retry => {
+            match (instance.should_try(), threads_allowed, accept_health_check) {
+                (InstanceAttempt::Failed, _, _)
+                | (InstanceAttempt::Retry, true, _)
+                | (InstanceAttempt::Retry, false, HealthCheck::Avoid) => continue,
+                (InstanceAttempt::Working, _, _) => {
                     // This is not true round-robin in case of multithreaded access
                     // This is degraded mode, so a best-effort attempt is acceptable
                     self.slot.instance_balancer.fetch_add(i, Relaxed);
-                    return instance;
+                    return (instance, ShouldHealthCheck::RunDirectly);
+                }
+                (InstanceAttempt::Retry, false, HealthCheck::Possible) => {
+                    // This is not true round-robin in case of multithreaded access
+                    // This is degraded mode, so a best-effort attempt is acceptable
+                    self.slot.instance_balancer.fetch_add(i, Relaxed);
+                    return (instance, ShouldHealthCheck::HealthCheckFirst);
                 }
             }
         }
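The three-way match above encodes a small decision table. A standalone restatement of the policy (types simplified, names hypothetical; `InstanceAttempt`, `HealthCheck`, and the outcomes mirror the diff):

```rust
#[derive(Clone, Copy)]
enum InstanceAttempt { Working, Retry, Failed }
#[derive(Clone, Copy)]
enum HealthCheck { Possible, Avoid }
#[derive(Clone, Copy, Debug, PartialEq)]
enum Decision { UseDirectly, HealthCheckFirst, Skip }

/// threads_allowed == true means a background thread already probes failed
/// instances, so a Retry candidate is skipped here instead of probed inline.
fn decide(attempt: InstanceAttempt, threads_allowed: bool, hc: HealthCheck) -> Decision {
    match (attempt, threads_allowed, hc) {
        (InstanceAttempt::Working, _, _) => Decision::UseDirectly,
        (InstanceAttempt::Retry, false, HealthCheck::Possible) => Decision::HealthCheckFirst,
        // Failed instances, and Retry candidates that either belong to the
        // background thread or that we are unwilling to probe, are skipped.
        _ => Decision::Skip,
    }
}
```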
@@ -184,23 +256,36 @@
         // No instance is valid, return a failed instance for an attempt
         let index = self.slot.instance_balancer.fetch_add(1, Relaxed);
         let index = index % self.slot.instances.len();
-        &self.slot.instances[index]
+        // Instance is not valid, don't try health check, it would only slow things down
+        (&self.slot.instances[index], ShouldHealthCheck::RunDirectly)
     }
 
-    fn operator(&self) -> Option<InstanceData> {
-        get_user_api_config(self.operator_config(), self.next_instance())
+    fn operator(
+        &self,
+        accept_health_check: HealthCheck,
+    ) -> Option<(InstanceData, ShouldHealthCheck)> {
+        let (instance, should_health_check) = self.next_instance(accept_health_check);
+        get_user_api_config(self.operator_config(), instance).map(|c| (c, should_health_check))
     }
 
-    fn administrator(&self) -> Option<InstanceData> {
-        get_user_api_config(self.admin_config(), self.next_instance())
+    fn administrator(
+        &self,
+        accept_health_check: HealthCheck,
+    ) -> Option<(InstanceData, ShouldHealthCheck)> {
+        let (instance, should_health_check) = self.next_instance(accept_health_check);
+        get_user_api_config(self.admin_config(), instance).map(|c| (c, should_health_check))
     }
 
-    fn operator_or_administrator(&self) -> Option<InstanceData> {
-        self.operator().or_else(|| self.administrator())
+    fn operator_or_administrator(
+        &self,
+        accept_health_check: HealthCheck,
+    ) -> Option<(InstanceData, ShouldHealthCheck)> {
+        self.operator(accept_health_check)
+            .or_else(|| self.administrator(accept_health_check))
     }
 
-    fn guest(&self) -> &InstanceData {
-        self.next_instance()
+    fn guest(&self, accept_health_check: HealthCheck) -> (&InstanceData, ShouldHealthCheck) {
+        self.next_instance(accept_health_check)
     }
 
     pub fn can_run_mode(&self, mode: UserMode) -> bool {
@@ -225,12 +310,21 @@
         self.ck_state = CKS_RO_PUBLIC_SESSION;
     }
 
-    pub fn get_config_user_mode(&self, user_mode: &UserMode) -> Option<InstanceData> {
+    fn get_config_user_mode(
+        &self,
+        user_mode: &UserMode,
+        accept_health_check: HealthCheck,
+    ) -> Option<(InstanceData, ShouldHealthCheck)> {
         match user_mode {
-            UserMode::Operator => self.operator(),
-            UserMode::Administrator => self.administrator(),
-            UserMode::Guest => Some(self.guest().clone()),
-            UserMode::OperatorOrAdministrator => self.operator_or_administrator(),
+            UserMode::Operator => self.operator(accept_health_check),
+            UserMode::Administrator => self.administrator(accept_health_check),
+            UserMode::Guest => {
+                let (instance, should_health_check) = self.guest(accept_health_check);
+                Some((instance.clone(), should_health_check))
+            }
+            UserMode::OperatorOrAdministrator => {
+                self.operator_or_administrator(accept_health_check)
+            }
         }
     }
 
@@ -239,8 +333,11 @@
     where
         F: FnOnce(&Configuration) -> Result<R, apis::Error<T>> + Clone,
     {
+        let mut health_check_count = 0; // we loop for a maximum of instances.len() times
-        let Some(mut instance) = self.get_config_user_mode(&user_mode) else {
+        let Some((mut instance, mut should_health_check)) =
+            self.get_config_user_mode(&user_mode, HealthCheck::Possible)
+        else {
             return Err(Error::Login(LoginError::UserNotPresent));
         };
 
@@ -256,12 +353,30 @@
         let delay = Duration::from_secs(delay_seconds);
 
         loop {
+            let accept_health_check = if health_check_count < 3 {
+                HealthCheck::Possible
+            } else {
+                HealthCheck::Avoid
+            };
             if retry_count > retry_limit {
                 error!(
                     "Retry count exceeded after {retry_limit} attempts, instance is unreachable"
                 );
                 return Err(ApiError::InstanceRemoved.into());
             }
+
+            if should_health_check.should_check() && !health_check_get_timeout(&instance) {
+                health_check_count += 1;
+                // Instance is not valid, we try the next one
+                if let Some((new_instance, new_should_health_check)) =
+                    self.get_config_user_mode(&user_mode, accept_health_check)
+                {
+                    instance = new_instance;
+                    should_health_check = new_should_health_check;
+                }
+                continue;
+            }
+
             retry_count += 1;
             let api_call_clone = api_call.clone();
             match api_call_clone(&instance.config) {
@@ -280,8 +395,11 @@
                     warn!("Connection attempt {retry_count} failed: Status error connecting to the instance, {:?}, retrying in {delay_seconds}s", err.status);
                     thread::sleep(delay);
 
-                    if let Some(new_conf) = self.get_config_user_mode(&user_mode) {
-                        instance = new_conf;
+                    if let Some((new_instance, new_should_health_check)) =
+                        self.get_config_user_mode(&user_mode, accept_health_check)
+                    {
+                        instance = new_instance;
+                        should_health_check = new_should_health_check;
                    }
                }
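Distilled, the loop added to `try_` tracks two separate budgets: failed inline health checks (capped at 3, after which probing is avoided) and real connection attempts (capped at `retry_limit`). A simplified sketch of that control flow, with `pick`, `probe`, and `call` standing in for `get_config_user_mode`, `health_check_get_timeout`, and the API call (error handling and sleeping omitted):

```rust
const MAX_HEALTH_CHECKS: u32 = 3; // mirrors `health_check_count < 3`

fn call_with_retries<I, T, E>(
    mut pick: impl FnMut(bool) -> (I, bool), // bool out: candidate still needs a probe
    probe: impl Fn(&I) -> bool,              // the 1-second readiness probe
    call: impl Fn(&I) -> Result<T, E>,
    retry_limit: u32,
) -> Option<T> {
    let (mut instance, mut needs_probe) = pick(true);
    let mut health_checks = 0u32;
    let mut retries = 0u32;
    loop {
        if retries > retry_limit {
            return None; // give up: instance unreachable
        }
        // A failed probe consumes a health-check credit but no retry credit;
        // the candidate is swapped out before any real API call is made.
        if needs_probe && !probe(&instance) {
            health_checks += 1;
            (instance, needs_probe) = pick(health_checks < MAX_HEALTH_CHECKS);
            continue;
        }
        retries += 1;
        match call(&instance) {
            Ok(v) => return Some(v),
            // The real code also sleeps and distinguishes status/IO errors;
            // here we simply rotate to the next candidate.
            Err(_) => (instance, needs_probe) = pick(health_checks < MAX_HEALTH_CHECKS),
        }
    }
}
```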
@@ -296,8 +414,11 @@ impl LoginCtx {
                     instance.bump_failed();
                     warn!("Connection attempt {retry_count} failed: IO error connecting to the instance, {err}, retrying in {delay_seconds}s");
                     thread::sleep(delay);
-                    if let Some(new_conf) = self.get_config_user_mode(&user_mode) {
-                        instance = new_conf;
+                    if let Some((new_instance, new_should_health_check)) =
+                        self.get_config_user_mode(&user_mode, accept_health_check)
+                    {
+                        instance = new_instance;
+                        should_health_check = new_should_health_check;
                     }
                 }
                 // Otherwise, return the error
@@ -349,7 +470,7 @@ impl LoginCtx {
     }
 }
 
-#[derive(Clone, Debug, PartialEq)]
+#[derive(Clone, Copy, Debug, PartialEq)]
 pub enum UserMode {
     Operator,
     Administrator,
@@ -366,29 +487,23 @@ pub enum UserStatus {
 
 pub fn get_current_user_status(
     api_config: &nethsm_sdk_rs::apis::configuration::Configuration,
-) -> UserStatus {
+) -> Result<UserStatus, nethsm_sdk_rs::apis::Error<default_api::UsersUserIdGetError>> {
     let auth = match api_config.basic_auth.as_ref() {
         Some(auth) => auth,
-        None => return UserStatus::LoggedOut,
+        None => return Ok(UserStatus::LoggedOut),
     };
 
     if auth.1.is_none() {
-        return UserStatus::LoggedOut;
+        return Ok(UserStatus::LoggedOut);
     }
 
-    let user = match default_api::users_user_id_get(api_config, auth.0.as_str()) {
-        Ok(user) => user.entity,
-        Err(err) => {
-            error!("Failed to get user: {:?}", err);
-            return UserStatus::LoggedOut;
-        }
-    };
+    let user = default_api::users_user_id_get(api_config, auth.0.as_str())?;
 
-    match user.role {
+    Ok(match user.entity.role {
         UserRole::Operator => UserStatus::Operator,
         UserRole::Administrator => UserStatus::Administrator,
         _ => UserStatus::LoggedOut,
-    }
+    })
 }
 
 // Check if the user is logged in and then return the configuration to connect as this user
 fn get_user_api_config(
diff --git a/pkcs11/src/config/device.rs b/pkcs11/src/config/device.rs
index 324612b..3fad626 100644
--- a/pkcs11/src/config/device.rs
+++ b/pkcs11/src/config/device.rs
@@ -2,8 +2,8 @@ use std::{
     collections::BTreeMap,
     sync::{
         atomic::{AtomicUsize, Ordering::Relaxed},
-        mpsc::{self, RecvTimeoutError},
-        Arc, Condvar, LazyLock, Mutex, RwLock,
+        mpsc::{self, RecvError, RecvTimeoutError},
+        Arc, Condvar, LazyLock, Mutex, RwLock, Weak,
     },
     thread,
     time::{Duration, Instant},
@@ -15,7 +15,17 @@ use crate::{backend::db::Db, data::THREADS_ALLOWED};
 
 use super::config_file::{RetryConfig, UserConfig};
 
-static RETRY_THREAD: LazyLock<mpsc::Sender<(Duration, InstanceData)>> = LazyLock::new(|| {
+#[allow(clippy::large_enum_variant)]
+pub enum RetryThreadMessage {
+    FailedInstance {
+        retry_in: Duration,
+        instance: InstanceData,
+    },
+    /// The library is being finalized, clear all connections
+    Finalize,
+}
+
+pub static RETRY_THREAD: LazyLock<mpsc::Sender<RetryThreadMessage>> = LazyLock::new(|| {
     let (tx, rx) = mpsc::channel();
     let (tx_instance, rx_instance) = mpsc::channel();
     thread::spawn(background_thread(rx_instance));
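The timer thread consuming these messages (next hunk) keeps pending retries in a `BTreeMap<Instant, _>`, so the earliest deadline can always be popped first, and uses `recv_timeout` to sleep exactly until that deadline while staying wakeable by new messages. The shape of that loop, reduced to a toy `u32` payload (the real queue stores `WeakInstanceData`):

```rust
use std::collections::BTreeMap;
use std::sync::mpsc;
use std::time::{Duration, Instant};

enum Msg {
    Schedule { run_in: Duration, job: u32 },
    Clear, // analogous to RetryThreadMessage::Finalize
}

fn timer_loop(rx: mpsc::Receiver<Msg>, tx_due: mpsc::Sender<u32>) {
    // Ordered by deadline: pop_first() always yields the soonest job.
    let mut jobs: BTreeMap<Instant, u32> = BTreeMap::new();
    loop {
        let Some((deadline, job)) = jobs.pop_first() else {
            // Queue empty: block until something arrives.
            match rx.recv() {
                Ok(Msg::Schedule { run_in, job }) => {
                    jobs.insert(Instant::now() + run_in, job);
                }
                Ok(Msg::Clear) => {}
                Err(_) => return, // all senders gone
            }
            continue;
        };
        let now = Instant::now();
        if now >= deadline {
            let _ = tx_due.send(job); // job is due: hand it off
            continue;
        }
        jobs.insert(deadline, job); // not due yet: put it back
        // Sleep until the deadline, but wake early for new messages.
        match rx.recv_timeout(deadline - now) {
            Ok(Msg::Schedule { run_in, job }) => {
                jobs.insert(now + run_in, job);
            }
            Ok(Msg::Clear) => jobs.clear(),
            Err(mpsc::RecvTimeoutError::Timeout) => {}
            Err(mpsc::RecvTimeoutError::Disconnected) => return,
        }
    }
}
```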
@@ -24,34 +34,43 @@ });
 
 fn background_timer(
-    rx: mpsc::Receiver<(Duration, InstanceData)>,
+    rx: mpsc::Receiver<RetryThreadMessage>,
     tx_instance: mpsc::Sender<InstanceData>,
 ) -> impl FnOnce() {
-    let mut jobs: BTreeMap<Instant, InstanceData> = BTreeMap::new();
+    let mut jobs: BTreeMap<Instant, WeakInstanceData> = BTreeMap::new();
     move || loop {
         let next_job = jobs.pop_first();
         let Some((next_job_deadline, next_job_instance)) = next_job else {
             // No jobs in the queue, we can just run the next
-            let Ok((new_job_duration, new_state)) = rx.recv() else {
-                return;
-            };
-
-            jobs.insert(Instant::now() + new_job_duration, new_state);
-            continue;
+            match rx.recv() {
+                Err(RecvError) => break,
+                Ok(RetryThreadMessage::Finalize) => continue,
+                Ok(RetryThreadMessage::FailedInstance { retry_in, instance }) => {
+                    jobs.insert(Instant::now() + retry_in, instance.into());
+                    continue;
+                }
+            }
         };
 
         let now = Instant::now();
         if now >= next_job_deadline {
-            tx_instance.send(next_job_instance).unwrap();
-            continue;
+            if let Some(instance) = next_job_instance.upgrade() {
+                tx_instance.send(instance).unwrap();
+                continue;
+            }
+        } else {
+            jobs.insert(next_job_deadline, next_job_instance);
         }
-        jobs.insert(next_job_deadline, next_job_instance);
 
         let timeout = next_job_deadline.duration_since(now);
         match rx.recv_timeout(timeout) {
-            Ok((run_in, new_instance)) => {
-                jobs.insert(now + run_in, new_instance);
+            Ok(RetryThreadMessage::Finalize) => {
+                jobs.clear();
+                continue;
+            }
+            Ok(RetryThreadMessage::FailedInstance { retry_in, instance }) => {
+                jobs.insert(now + retry_in, instance.into());
                 continue;
             }
             Err(RecvTimeoutError::Timeout) => continue,
@@ -63,6 +82,7 @@
 fn background_thread(rx: mpsc::Receiver<InstanceData>) -> impl FnOnce() {
     move || loop {
         while let Ok(instance) = rx.recv() {
+            instance.config.client.clear_pool();
             match health_ready_get(&instance.config) {
                 Ok(_) => instance.clear_failed(),
                 Err(_) => instance.bump_failed(),
@@ -103,6 +123,31 @@ pub struct InstanceData {
     pub state: Arc<RwLock<InstanceState>>,
 }
 
+#[derive(Debug, Clone)]
+pub struct WeakInstanceData {
+    pub config: Configuration,
+    pub state: Weak<RwLock<InstanceState>>,
+}
+
+impl From<InstanceData> for WeakInstanceData {
+    fn from(value: InstanceData) -> Self {
+        Self {
+            config: value.config,
+            state: Arc::downgrade(&value.state),
+        }
+    }
+}
+
+impl WeakInstanceData {
+    fn upgrade(self) -> Option<InstanceData> {
+        let state = self.state.upgrade()?;
+        Some(InstanceData {
+            config: self.config,
+            state,
+        })
+    }
+}
+
 pub enum InstanceAttempt {
     /// The instance is in the failed state and should not be used
     Failed,
@@ -163,7 +208,10 @@ impl InstanceData {
         drop(write);
         if THREADS_ALLOWED.load(Relaxed) {
             RETRY_THREAD
-                .send((retry_duration_from_count(retry_count), self.clone()))
+                .send(RetryThreadMessage::FailedInstance {
+                    retry_in: retry_duration_from_count(retry_count),
+                    instance: self.clone(),
+                })
                 .ok();
         }
     }
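Storing `WeakInstanceData` in the retry queue means a queued job cannot keep an instance alive once everything else has dropped it (for example after `C_Finalize`). A minimal, self-contained illustration of that upgrade-or-discard pattern (names hypothetical; the real type also carries a `Configuration`):

```rust
use std::sync::{Arc, RwLock, Weak};

struct State {
    failed: bool,
}

fn handle_due_job(job: Weak<RwLock<State>>) {
    // upgrade() fails once every strong Arc owner is gone, e.g. after the
    // slot was dropped on finalize; the stale job is then silently skipped.
    if let Some(state) = job.upgrade() {
        state.write().unwrap().failed = false; // probe succeeded, clear flag
    }
}

fn main() {
    let state = Arc::new(RwLock::new(State { failed: true }));
    let job = Arc::downgrade(&state);
    handle_due_job(job.clone()); // owner alive: flag cleared
    drop(state);
    handle_due_job(job); // owner gone: upgrade() is None, job discarded
}
```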
diff --git a/pkcs11/tests/tools/mod.rs b/pkcs11/tests/tools/mod.rs
index 09b8feb..87583dd 100644
--- a/pkcs11/tests/tools/mod.rs
+++ b/pkcs11/tests/tools/mod.rs
@@ -2,6 +2,7 @@ use std::collections::HashSet;
 use std::io::BufWriter;
 use std::mem;
 use std::net::Ipv4Addr;
+use std::ptr;
 use std::sync::{Arc, LazyLock, Mutex, MutexGuard};
 use std::thread::sleep;
 use std::time::Duration;
@@ -16,7 +17,7 @@ use nethsm_sdk_rs::{
     },
     models::{ProvisionRequestData, UserPostData, UserRole},
 };
-use pkcs11::Ctx;
+use pkcs11::{types::CK_C_INITIALIZE_ARGS, Ctx};
 use rustls::{
     client::danger::ServerCertVerifier,
     crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider},
@@ -197,8 +198,8 @@ impl TestContext {
     }
 }
 
-impl Drop for TestDropper {
-    fn drop(&mut self) {
+impl TestDropper {
+    fn clear(&mut self) {
         for p in self.context.blocked_ports.iter().cloned() {
             TestContext::unblock(p);
         }
     }
 }
 
-static PROXY_SENDER: LazyLock<UnboundedSender<(u16, u16, broadcast::Sender<()>)>> =
-    LazyLock::new(|| {
-        let (tx, mut rx) = unbounded_channel();
-        std::thread::spawn(move || {
-            runtime::Builder::new_current_thread()
-                .enable_io()
-                .build()
-                .unwrap()
-                .block_on(async move {
-                    let mut tasks = Vec::new();
-                    while let Some((from_port, to_port, sender)) = rx.recv().await {
-                        tasks.push(tokio::spawn(proxy(from_port, to_port, sender)));
-                    }
-                    for task in tasks {
-                        task.abort();
-                    }
-                })
-        });
-        tx
-    });
+impl Drop for TestDropper {
+    fn drop(&mut self) {
+        self.clear();
+    }
+}
+
+enum ProxyMessage {
+    NewProxy(u16, u16, broadcast::Sender<()>),
+    CloseAll,
+}
+
+static PROXY_SENDER: LazyLock<UnboundedSender<ProxyMessage>> = LazyLock::new(|| {
+    let (tx, mut rx) = unbounded_channel();
+    std::thread::spawn(move || {
+        runtime::Builder::new_current_thread()
+            .enable_io()
+            .build()
+            .unwrap()
+            .block_on(async move {
+                let mut tasks = Vec::new();
+                while let Some(msg) = rx.recv().await {
+                    match msg {
+                        ProxyMessage::NewProxy(from_port, to_port, sender) => {
+                            tasks.push(tokio::spawn(proxy(from_port, to_port, sender)))
+                        }
+                        ProxyMessage::CloseAll => {
+                            for task in mem::take(&mut tasks) {
+                                task.abort();
+                            }
+                        }
+                    }
+                }
+                for task in tasks {
+                    task.abort();
+                }
+            })
+    });
+    tx
+});
 
 async fn proxy(from_port: u16, to_port: u16, stall_sender: broadcast::Sender<()>) {
     let listener = TcpListener::bind(((Ipv4Addr::from([127, 0, 0, 1])), from_port))
         .await
@@ -299,7 +319,7 @@ static DOCKER_HELD: Mutex<bool> = Mutex::new(false);
 pub fn run_tests(
     proxies: &[(u16, u16)],
     config: P11Config,
-    f: impl FnOnce(&mut TestContext, &mut Ctx),
+    f: impl FnOnce(&mut TestContext, &mut Ctx) + Clone,
 ) {
     let Ok(serialize_test) = DOCKER_HELD.lock() else {
         eprintln!("Test not run");
@@ -368,7 +388,7 @@ pub fn run_tests(
     for (in_port, out_port) in proxies {
         PROXY_SENDER
-            .send((
+            .send(ProxyMessage::NewProxy(
                 *in_port,
                 *out_port,
                 test_dropper.context.stall_connections.clone(),
@@ -381,8 +401,26 @@ pub fn run_tests(
     serde_yaml::to_writer(BufWriter::new(tmpfile.as_file_mut()), &config).unwrap();
     let path = tmpfile.path();
     set_var(config_file::ENV_VAR_CONFIG_FILE, path);
-    let mut ctx = Ctx::new_and_initialize("../target/release/libnethsm_pkcs11.so").unwrap();
-    f(&mut test_dropper.context, &mut ctx);
-    ctx.close_all_sessions(0).unwrap();
+    {
+        let mut ctx = Ctx::new_and_initialize("../target/release/libnethsm_pkcs11.so").unwrap();
+        let f_cl = f.clone();
+        f_cl(&mut test_dropper.context, &mut ctx);
+        ctx.close_all_sessions(0).unwrap();
+    }
+    {
+        let mut ctx = Ctx::new("../target/release/libnethsm_pkcs11.so").unwrap();
+        ctx.initialize(Some(CK_C_INITIALIZE_ARGS {
+            CreateMutex: None,
+            DestroyMutex: None,
+            LockMutex: None,
+            UnlockMutex: None,
+            flags: cryptoki_sys::CKF_LIBRARY_CANT_CREATE_OS_THREADS,
+            pReserved: ptr::null_mut(),
+        }))
+        .unwrap();
+        f(&mut test_dropper.context, &mut ctx);
+        ctx.close_all_sessions(0).unwrap();
+    }
+    PROXY_SENDER.send(ProxyMessage::CloseAll).unwrap();
     println!("Ending test");
 }
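The second run re-executes every test with `CKF_LIBRARY_CANT_CREATE_OS_THREADS`, covering the path where `THREADS_ALLOWED` is false and failed instances are health-checked inline rather than by the background retry thread. Any client can request the same behavior; a sketch using the same `pkcs11`-crate calls as the test harness above (the function name is illustrative):

```rust
use std::ptr;

use pkcs11::{types::CK_C_INITIALIZE_ARGS, Ctx};

// Mirrors the second test run: ask the module not to spawn OS threads,
// so retries fall back to the inline health-check path.
fn open_module_single_threaded(path: &str) -> Ctx {
    let mut ctx = Ctx::new(path).expect("failed to load the PKCS#11 module");
    ctx.initialize(Some(CK_C_INITIALIZE_ARGS {
        CreateMutex: None,
        DestroyMutex: None,
        LockMutex: None,
        UnlockMutex: None,
        flags: cryptoki_sys::CKF_LIBRARY_CANT_CREATE_OS_THREADS,
        pReserved: ptr::null_mut(),
    }))
    .expect("C_Initialize failed");
    ctx
}
```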