diff --git a/Cargo.lock b/Cargo.lock
index 6792dc9224..2ec3a7e58a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4691,13 +4691,13 @@ dependencies = [
  "intercept",
  "pico-args",
  "ractor",
- "ractor-supervisor",
  "sentry",
  "serde",
  "serde_json",
  "specta",
  "specta-typescript",
  "strum 0.27.2",
+ "supervisor",
  "tauri",
  "tauri-build",
  "tauri-plugin-analytics",
@@ -14393,32 +14393,23 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
 
 [[package]]
 name = "ractor"
-version = "0.14.7"
+version = "0.15.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1d65972a0286ef14c43c6daafbac6cf15e96496446147683b2905292c35cc178"
+checksum = "6102314f700f3e8df466c49110830b18cbfc172f88f27a9d7383e455663b1be7"
 dependencies = [
  "async-trait",
  "bon 2.3.0",
  "dashmap",
  "futures",
+ "js-sys",
  "once_cell",
  "strum 0.26.3",
  "tokio",
+ "tokio_with_wasm",
  "tracing",
-]
-
-[[package]]
-name = "ractor-supervisor"
-version = "0.1.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d90830688ebfafdc226f3c9567c40fecf4c51a7513171181102ae66e4b57c15f"
-dependencies = [
- "futures-util",
- "if_chain",
- "log",
- "ractor",
- "thiserror 2.0.18",
- "uuid",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-time",
 ]
 
 [[package]]
@@ -17506,6 +17497,16 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "supervisor"
+version = "0.1.0"
+dependencies = [
+ "ractor",
+ "thiserror 2.0.18",
+ "tokio",
+ "tracing",
+]
+
 [[package]]
 name = "svgtypes"
 version = "0.15.3"
@@ -18910,7 +18911,6 @@ dependencies = [
  "quickcheck",
  "quickcheck_macros",
  "ractor",
- "ractor-supervisor",
  "rodio",
  "sentry",
  "serde",
@@ -18918,6 +18918,7 @@ dependencies = [
  "specta",
  "specta-typescript",
  "strum 0.27.2",
+ "supervisor",
  "tauri",
  "tauri-plugin",
  "tauri-plugin-fs-sync",
@@ -19023,7 +19024,6 @@ dependencies = [
  "port-killer",
  "port_check",
  "ractor",
- "ractor-supervisor",
  "reqwest 0.13.2",
  "rodio",
  "serde",
@@ -19032,6 +19032,7 @@ dependencies = [
  "specta",
  "specta-typescript",
  "strum 0.27.2",
+ "supervisor",
  "tauri",
  "tauri-plugin",
  "tauri-plugin-settings",
@@ -19090,7 +19091,6 @@ name = "tauri-plugin-network"
 version = "0.1.0"
 dependencies = [
  "ractor",
- "ractor-supervisor",
  "reqwest 0.13.2",
  "serde",
  "specta",
@@ -20395,6 +20395,30 @@ dependencies = [
  "webpki-roots 0.26.11",
 ]
 
+[[package]]
+name = "tokio_with_wasm"
+version = "0.8.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34e40fbbbd95441133fe9483f522db15dbfd26dc636164ebd8f2dd28759a6aa6"
+dependencies = [
+ "js-sys",
+ "tokio",
+ "tokio_with_wasm_proc",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
+
+[[package]]
+name = "tokio_with_wasm_proc"
+version = "0.8.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d01145a2c788d6aae4cd653afec1e8332534d7d783d01897cefcafe4428de992"
+dependencies = [
+ "quote",
+ "syn 2.0.114",
+]
+
 [[package]]
 name = "toml"
 version = "0.8.2"
diff --git a/Cargo.toml b/Cargo.toml
index 086fc7c49d..f9091e569f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -198,8 +198,8 @@ async-stream = "0.3.6"
 futures-channel = "0.3.31"
 futures-core = "0.3.31"
 futures-util = "0.3.31"
-ractor = "0.14"
-ractor-supervisor = "0.1.9"
+hypr-supervisor = { path = "crates/supervisor", package = "supervisor" }
+ractor = "0.15.10"
 rayon = "1.11"
 reqwest = "0.13"
 reqwest-middleware = "0.5"
diff --git a/apps/desktop/src-tauri/Cargo.toml b/apps/desktop/src-tauri/Cargo.toml
index fe70712dc5..16c940f4e0 100644
--- a/apps/desktop/src-tauri/Cargo.toml
+++ b/apps/desktop/src-tauri/Cargo.toml
@@ -94,8 +94,8 @@
 tracing = { workspace = true }
 
 hypr-host = { workspace = true }
+hypr-supervisor = { workspace = true }
 ractor = { workspace = true }
-ractor-supervisor = { workspace = true }
 
 pico-args = "0.5"
 tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
diff --git a/apps/desktop/src-tauri/src/supervisor.rs b/apps/desktop/src-tauri/src/supervisor.rs
index b12d913795..cd8cb1c8df 100644
--- a/apps/desktop/src-tauri/src/supervisor.rs
+++ b/apps/desktop/src-tauri/src/supervisor.rs
@@ -1,11 +1,9 @@
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
 
+use hypr_supervisor::dynamic::{DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions};
 use ractor::ActorRef;
 use ractor::concurrency::Duration;
-use ractor_supervisor::dynamic::{
-    DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions,
-};
 
 pub type SupervisorRef = ActorRef<DynamicSupervisorMsg>;
 pub type SupervisorHandle = tokio::task::JoinHandle<()>;
diff --git a/crates/supervisor/Cargo.toml b/crates/supervisor/Cargo.toml
new file mode 100644
index 0000000000..abe4447dad
--- /dev/null
+++ b/crates/supervisor/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "supervisor"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+ractor = { workspace = true, features = ["async-trait"] }
+thiserror = { workspace = true }
+tokio = { workspace = true }
+tracing = { workspace = true }
diff --git a/crates/supervisor/src/dynamic.rs b/crates/supervisor/src/dynamic.rs
new file mode 100644
index 0000000000..f6f864ba2d
--- /dev/null
+++ b/crates/supervisor/src/dynamic.rs
@@ -0,0 +1,806 @@
+use std::collections::HashMap;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use ractor::{
+    Actor, ActorCell, ActorId, ActorProcessingErr, ActorRef, RpcReplyPort, SpawnErr,
+    SupervisionEvent, concurrency::JoinHandle,
+};
+
+// ---------------------------------------------------------------------------
+// Public types
+// ---------------------------------------------------------------------------
+
+#[derive(Debug, Clone, thiserror::Error)]
+pub enum SupervisorError {
+    #[error("Child '{child_id}' not found in specs")]
+    ChildNotFound { child_id: String },
+
+    #[error("Child '{pid}' does not have a name set")]
+    ChildNameNotSet { pid: ActorId },
+
+    #[error("Max children exceeded")]
+    MaxChildrenExceeded,
+
+    #[error("Meltdown: {reason}")]
+    Meltdown { reason: String },
+}
+
+pub type DynSpawnFuture = Pin<Box<dyn Future<Output = Result<ActorCell, SpawnErr>> + Send>>;
+
+#[derive(Clone)]
+pub struct DynSpawnFn(Arc<dyn Fn(ActorCell, String) -> DynSpawnFuture + Send + Sync>);
+
+impl DynSpawnFn {
+    pub fn new<F, Fut>(f: F) -> Self
+    where
+        F: Fn(ActorCell, String) -> Fut + Send + Sync + 'static,
+        Fut: Future<Output = Result<ActorCell, SpawnErr>> + Send + 'static,
+    {
+        Self(Arc::new(move |cell, id| Box::pin(f(cell, id))))
+    }
+
+    pub async fn call(&self, sup: ActorCell, id: String) -> Result<ActorCell, SpawnErr> {
+        (self.0)(sup, id).await
+    }
+}
+
+type BackoffFn = dyn Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync;
+
+#[derive(Clone)]
+pub struct ChildBackoffFn(Arc<BackoffFn>);
+
+impl ChildBackoffFn {
+    pub fn new<F>(f: F) -> Self
+    where
+        F: Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync + 'static,
+    {
+        Self(Arc::new(f))
+    }
+
+    fn call(
+        &self,
+        child_id: &str,
+        restart_count: usize,
+        last_fail: Instant,
+        reset_after: Option<Duration>,
+    ) -> Option<Duration> {
+        (self.0)(child_id, restart_count, last_fail, reset_after)
+    }
+}
+
+#[derive(Clone)]
+pub struct DynChildSpec {
+    pub id: String,
+    pub restart: crate::RestartPolicy,
+    pub spawn_fn: DynSpawnFn,
+    pub backoff_fn: Option<ChildBackoffFn>,
+    pub reset_after: Option<Duration>,
+}
+
+#[derive(Debug, Clone)]
+pub struct DynamicSupervisorOptions {
+    pub max_children: Option<usize>,
+    pub max_restarts: usize,
+    pub max_window: Duration,
+    pub reset_after: Option<Duration>,
+}
+
+// ---------------------------------------------------------------------------
+// Messages
+// ---------------------------------------------------------------------------
+
+pub enum DynamicSupervisorMsg {
+    SpawnChild {
+        spec: DynChildSpec,
+        reply: Option<RpcReplyPort<Result<(), ActorProcessingErr>>>,
+    },
+    TerminateChild {
+        child_id: String,
+        reply: Option<RpcReplyPort<()>>,
+    },
+    ScheduledRestart {
+        spec: DynChildSpec,
+    },
+}
+
+impl std::fmt::Debug for DynamicSupervisorMsg {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::SpawnChild { spec, .. } => {
+                f.debug_struct("SpawnChild").field("id", &spec.id).finish()
+            }
+            Self::TerminateChild { child_id, .. } => f
+                .debug_struct("TerminateChild")
+                .field("child_id", child_id)
+                .finish(),
+            Self::ScheduledRestart { spec } => f
+                .debug_struct("ScheduledRestart")
+                .field("id", &spec.id)
+                .finish(),
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Internal state
+// ---------------------------------------------------------------------------
+
+struct ActiveChild {
+    spec: DynChildSpec,
+    cell: ActorCell,
+}
+
+struct ChildFailureState {
+    restart_count: usize,
+    last_fail: Instant,
+}
+
+struct RestartLogEntry {
+    _child_id: String,
+    timestamp: Instant,
+}
+
+pub struct DynamicSupervisorState {
+    options: DynamicSupervisorOptions,
+    active_children: HashMap<String, ActiveChild>,
+    child_failure_state: HashMap<String, ChildFailureState>,
+    restart_log: Vec<RestartLogEntry>,
+}
+
+// ---------------------------------------------------------------------------
+// Actor
+// ---------------------------------------------------------------------------
+
+pub struct DynamicSupervisor;
+
+impl DynamicSupervisor {
+    pub async fn spawn(
+        name: String,
+        options: DynamicSupervisorOptions,
+    ) -> Result<(ActorRef<DynamicSupervisorMsg>, JoinHandle<()>), SpawnErr> {
+        Actor::spawn(Some(name), DynamicSupervisor, options).await
+    }
+
+    pub async fn spawn_linked<T: Actor>(
+        name: impl Into<String>,
+        handler: T,
+        args: T::Arguments,
+        supervisor: ActorCell,
+    ) -> Result<(ActorRef<T::Msg>, JoinHandle<()>), SpawnErr> {
+        Actor::spawn_linked(Some(name.into()), handler, args, supervisor).await
+    }
+
+    pub async fn spawn_child(
+        sup_ref: ActorRef<DynamicSupervisorMsg>,
+        spec: DynChildSpec,
+    ) -> Result<(), ActorProcessingErr> {
+        ractor::call!(sup_ref, |reply| {
+            DynamicSupervisorMsg::SpawnChild {
+                spec,
+                reply: Some(reply),
+            }
+        })?
+    }
+
+    pub async fn terminate_child(
+        sup_ref: ActorRef<DynamicSupervisorMsg>,
+        child_id: String,
+    ) -> Result<(), ActorProcessingErr> {
+        ractor::call!(sup_ref, |reply| {
+            DynamicSupervisorMsg::TerminateChild {
+                child_id,
+                reply: Some(reply),
+            }
+        })?;
+        Ok(())
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Meltdown tracking
+// ---------------------------------------------------------------------------
+
+impl DynamicSupervisorState {
+    fn track_global_restart(&mut self, child_id: &str) -> Result<(), ActorProcessingErr> {
+        let now = Instant::now();
+
+        if let Some(reset_after) = self.options.reset_after {
+            if let Some(latest) = self.restart_log.last() {
+                if now.duration_since(latest.timestamp) >= reset_after {
+                    self.restart_log.clear();
+                }
+            }
+        }
+
+        self.restart_log.push(RestartLogEntry {
+            _child_id: child_id.to_string(),
+            timestamp: now,
+        });
+
+        self.restart_log
+            .retain(|e| now.duration_since(e.timestamp) < self.options.max_window);
+
+        if self.restart_log.len() > self.options.max_restarts {
+            Err(SupervisorError::Meltdown {
+                reason: "max_restarts exceeded".to_string(),
+            }
+            .into())
+        } else {
+            Ok(())
+        }
+    }
+
+    fn prepare_child_failure(&mut self, spec: &DynChildSpec) {
+        let now = Instant::now();
+        let entry = self
+            .child_failure_state
+            .entry(spec.id.clone())
+            .or_insert(ChildFailureState {
+                restart_count: 0,
+                last_fail: now,
+            });
+
+        if let Some(threshold) = spec.reset_after {
+            if now.duration_since(entry.last_fail) >= threshold {
+                entry.restart_count = 0;
+            }
+        }
+
+        entry.restart_count += 1;
+        entry.last_fail = now;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Actor implementation
+// ---------------------------------------------------------------------------
+
+#[ractor::async_trait]
+impl Actor for DynamicSupervisor {
+    type Msg = DynamicSupervisorMsg;
+    type State = DynamicSupervisorState;
+    type Arguments = DynamicSupervisorOptions;
+
+    async fn pre_start(
+        &self,
+        _myself: ActorRef<Self::Msg>,
+        options: Self::Arguments,
+    ) -> Result<Self::State, ActorProcessingErr> {
+        Ok(DynamicSupervisorState {
+            options,
+            active_children: HashMap::new(),
+            child_failure_state: HashMap::new(),
+            restart_log: Vec::new(),
+        })
+    }
+
+    async fn handle(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        msg: Self::Msg,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        match msg {
+            DynamicSupervisorMsg::SpawnChild { spec, reply } => {
+                let result =
+                    handle_spawn_child(&spec, reply.is_some(), state, myself.clone()).await;
+                if let Some(reply) = reply {
+                    reply.send(result)?;
+                    Ok(())
+                } else {
+                    result
+                }
+            }
+            DynamicSupervisorMsg::TerminateChild { child_id, reply } => {
+                handle_terminate_child(&child_id, state, &myself);
+                if let Some(reply) = reply {
+                    reply.send(())?;
+                }
+                Ok(())
+            }
+            DynamicSupervisorMsg::ScheduledRestart { spec } => {
+                handle_spawn_child(&spec, false, state, myself).await
+            }
+        }
+    }
+
+    async fn handle_supervisor_evt(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        evt: SupervisionEvent,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        match evt {
+            SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {}
+
+            SupervisionEvent::ActorTerminated(cell, _, reason) => {
+                handle_child_restart(cell, false, state, &myself, reason.as_deref())?;
+            }
+
+            SupervisionEvent::ActorFailed(cell, err) => {
+                let reason = format!("{:?}", err);
+                handle_child_restart(cell, true, state, &myself, Some(&reason))?;
+            }
+        }
+        Ok(())
+    }
+}
+
+//
--------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +async fn handle_spawn_child( + spec: &DynChildSpec, + first_start: bool, + state: &mut DynamicSupervisorState, + myself: ActorRef, +) -> Result<(), ActorProcessingErr> { + if !first_start { + state.track_global_restart(&spec.id)?; + tokio::time::sleep(Duration::from_millis(10)).await; + } + + if let Some(max) = state.options.max_children { + if state.active_children.len() >= max { + return Err(SupervisorError::MaxChildrenExceeded.into()); + } + } + + let result = spec.spawn_fn.call(myself.get_cell(), spec.id.clone()).await; + + match result { + Ok(child_cell) => { + state.active_children.insert( + spec.id.clone(), + ActiveChild { + spec: spec.clone(), + cell: child_cell, + }, + ); + state + .child_failure_state + .entry(spec.id.clone()) + .or_insert(ChildFailureState { + restart_count: 0, + last_fail: Instant::now(), + }); + Ok(()) + } + Err(e) => Err(SupervisorError::Meltdown { + reason: format!("spawn failed for '{}': {}", spec.id, e), + } + .into()), + } +} + +fn handle_terminate_child( + child_id: &str, + state: &mut DynamicSupervisorState, + myself: &ActorRef, +) { + if let Some(child) = state.active_children.remove(child_id) { + child.cell.unlink(myself.get_cell()); + child.cell.kill(); + } +} + +fn handle_child_restart( + cell: ActorCell, + abnormal: bool, + state: &mut DynamicSupervisorState, + myself: &ActorRef, + _reason: Option<&str>, +) -> Result<(), ActorProcessingErr> { + let child_id = cell + .get_name() + .ok_or(SupervisorError::ChildNameNotSet { pid: cell.get_id() })?; + + let child = match state.active_children.remove(&child_id) { + Some(c) => c, + None => return Ok(()), + }; + + let should_restart = match child.spec.restart { + crate::RestartPolicy::Permanent => true, + crate::RestartPolicy::Transient => abnormal, + crate::RestartPolicy::Temporary => false, + }; + + if !should_restart { + return Ok(()); + } + + state.prepare_child_failure(&child.spec); + + let delay = child.spec.backoff_fn.as_ref().and_then(|bf| { + let fs = state.child_failure_state.get(&child.spec.id); + let (count, last_fail) = fs + .map(|f| (f.restart_count, f.last_fail)) + .unwrap_or((0, Instant::now())); + bf.call(&child.spec.id, count, last_fail, child.spec.reset_after) + }); + + let spec = child.spec.clone(); + match delay { + Some(d) => { + let dur = ractor::concurrency::Duration::from_millis(d.as_millis() as u64); + myself.send_after(dur, move || DynamicSupervisorMsg::ScheduledRestart { spec }); + } + None => { + myself.send_message(DynamicSupervisorMsg::ScheduledRestart { spec })?; + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::RestartPolicy; + use ractor::{ActorRef, ActorStatus}; + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; + + static TEST_SEQ: AtomicU64 = AtomicU64::new(0); + + fn unique_name(prefix: &str) -> String { + format!( + "{prefix}_{}_{}", + std::process::id(), + TEST_SEQ.fetch_add(1, Ordering::Relaxed) + ) + } + + #[derive(Clone)] + enum ChildBehavior { + Healthy, + DelayedFail { ms: u64 }, + DelayedNormal { ms: u64 }, + } + + struct TestChild { + counter: Arc, + } + + #[ractor::async_trait] + impl Actor for TestChild { + type Msg = (); + type State = ChildBehavior; + type Arguments = ChildBehavior; + + async fn pre_start( + &self, + myself: ActorRef, + behavior: Self::Arguments, + ) -> Result { + self.counter.fetch_add(1, Ordering::SeqCst); + + 
match behavior { + ChildBehavior::DelayedFail { ms } | ChildBehavior::DelayedNormal { ms } => { + myself.send_after(Duration::from_millis(ms), || ()); + } + ChildBehavior::Healthy => {} + } + + Ok(behavior) + } + + async fn handle( + &self, + myself: ActorRef, + _msg: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match state { + ChildBehavior::DelayedFail { .. } => panic!("delayed_fail"), + ChildBehavior::DelayedNormal { .. } => myself.stop(None), + ChildBehavior::Healthy => {} + } + Ok(()) + } + } + + fn make_spec( + id: &str, + restart: RestartPolicy, + behavior: ChildBehavior, + counter: Arc, + ) -> DynChildSpec { + let id = id.to_string(); + DynChildSpec { + id: id.clone(), + restart, + spawn_fn: DynSpawnFn::new(move |sup_cell, child_id| { + let behavior = behavior.clone(); + let counter = counter.clone(); + async move { + let (child_ref, _join) = DynamicSupervisor::spawn_linked( + child_id, + TestChild { counter }, + behavior, + sup_cell, + ) + .await?; + Ok(child_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: None, + } + } + + fn options(max_restarts: usize) -> DynamicSupervisorOptions { + DynamicSupervisorOptions { + max_children: None, + max_restarts, + max_window: Duration::from_secs(5), + reset_after: None, + } + } + + #[tokio::test] + async fn transient_child_no_restart_on_normal_exit() { + let sup_name = unique_name("dyn_transient_normal_sup"); + let child_name = unique_name("dyn_transient_normal_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, options(5)) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Transient, + ChildBehavior::DelayedNormal { ms: 50 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(180)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert!( + !sup_ref + .get_children() + .iter() + .any(|c| c.get_status() == ActorStatus::Running) + ); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn temporary_child_never_restarts_on_failure() { + let sup_name = unique_name("dyn_temporary_sup"); + let child_name = unique_name("dyn_temporary_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, options(5)) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Temporary, + ChildBehavior::DelayedFail { ms: 50 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(180)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert!( + !sup_ref + .get_children() + .iter() + .any(|c| c.get_status() == ActorStatus::Running) + ); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn permanent_child_triggers_meltdown_when_budget_exceeded() { + let sup_name = unique_name("dyn_meltdown_sup"); + let child_name = unique_name("dyn_meltdown_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut opts = options(1); + opts.max_window = Duration::from_secs(2); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to 
spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 40 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + assert_eq!(counter.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn reset_after_allows_restarts_across_quiet_periods() { + let sup_name = unique_name("dyn_reset_after_sup"); + let child_name = unique_name("dyn_reset_after_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut opts = options(1); + opts.max_window = Duration::from_secs(10); + opts.reset_after = Some(Duration::from_millis(80)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 140 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(520)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!(counter.load(Ordering::SeqCst) >= 3); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn max_children_is_enforced() { + let sup_name = unique_name("dyn_max_children_sup"); + let child_name_1 = unique_name("dyn_max_children_child1"); + let child_name_2 = unique_name("dyn_max_children_child2"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut opts = options(5); + opts.max_children = Some(1); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to spawn dynamic supervisor"); + + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name_1, + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + ), + ) + .await + .expect("first child should spawn"); + + let second = DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name_2, + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + ), + ) + .await; + + assert!(second.is_err()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn terminate_child_does_not_restart() { + let sup_name = unique_name("dyn_terminate_sup"); + let child_name = unique_name("dyn_terminate_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, options(5)) + .await + .expect("failed to spawn dynamic supervisor"); + + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(40)).await; + DynamicSupervisor::terminate_child(sup_ref.clone(), child_name) + .await + .expect("failed to terminate child"); + tokio::time::sleep(Duration::from_millis(120)).await; + + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!( + !sup_ref + .get_children() + .iter() + .any(|c| c.get_status() == ActorStatus::Running) + ); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn backoff_delays_second_restart_attempt() { + let 
sup_name = unique_name("dyn_backoff_sup"); + let child_name = unique_name("dyn_backoff_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut spec = make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 10 }, + counter.clone(), + ); + spec.backoff_fn = Some(ChildBackoffFn::new( + |_id, restart_count, _last, _child_reset| { + if restart_count <= 1 { + None + } else { + Some(Duration::from_millis(220)) + } + }, + )); + + let mut opts = options(1); + opts.max_window = Duration::from_secs(5); + + let start = std::time::Instant::now(); + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child(sup_ref.clone(), spec) + .await + .expect("failed to spawn child"); + + let _ = sup_handle.await; + let elapsed = start.elapsed(); + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + assert!( + elapsed >= Duration::from_millis(200), + "expected delayed restart, got {elapsed:?}" + ); + assert_eq!(counter.load(Ordering::SeqCst), 2); + } +} diff --git a/crates/supervisor/src/lib.rs b/crates/supervisor/src/lib.rs new file mode 100644 index 0000000000..d145303e7e --- /dev/null +++ b/crates/supervisor/src/lib.rs @@ -0,0 +1,10 @@ +pub mod dynamic; +mod restart; +mod retry; +mod supervisor; + +pub use restart::{RestartBudget, RestartTracker}; +pub use retry::{RetryStrategy, spawn_with_retry}; +pub use supervisor::{ + ChildSpec, RestartPolicy, SpawnFn, Supervisor, SupervisorConfig, SupervisorMsg, +}; diff --git a/crates/supervisor/src/restart.rs b/crates/supervisor/src/restart.rs new file mode 100644 index 0000000000..0d28c16679 --- /dev/null +++ b/crates/supervisor/src/restart.rs @@ -0,0 +1,164 @@ +use std::time::{Duration, Instant}; + +#[derive(Debug, Clone)] +pub struct RestartBudget { + pub max_restarts: u32, + pub max_window: Duration, + pub reset_after: Option, +} + +impl Default for RestartBudget { + fn default() -> Self { + Self { + max_restarts: 3, + max_window: Duration::from_secs(15), + reset_after: Some(Duration::from_secs(30)), + } + } +} + +pub struct RestartTracker { + count: u32, + window_start: Instant, +} + +impl Default for RestartTracker { + fn default() -> Self { + Self::new() + } +} + +impl RestartTracker { + pub fn new() -> Self { + Self { + count: 0, + window_start: Instant::now(), + } + } + + /// Records a restart and returns whether the budget still allows it. + /// `true` = within budget, `false` = meltdown threshold exceeded. + pub fn record_restart(&mut self, budget: &RestartBudget) -> bool { + let now = Instant::now(); + if now.duration_since(self.window_start) > budget.max_window { + self.count = 0; + self.window_start = now; + } + self.count += 1; + self.count <= budget.max_restarts + } + + /// Resets the counter if quiet for longer than `budget.reset_after`. 
+ pub fn maybe_reset(&mut self, budget: &RestartBudget) { + if let Some(reset_after) = budget.reset_after { + let now = Instant::now(); + if now.duration_since(self.window_start) > reset_after { + self.count = 0; + self.window_start = now; + } + } + } + + pub fn count(&self) -> u32 { + self.count + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn budget(max_restarts: u32, max_window_ms: u64) -> RestartBudget { + RestartBudget { + max_restarts, + max_window: Duration::from_millis(max_window_ms), + reset_after: None, + } + } + + #[test] + fn new_tracker_starts_at_zero() { + let tracker = RestartTracker::new(); + assert_eq!(tracker.count(), 0); + } + + #[test] + fn record_restart_increments_within_budget() { + let mut tracker = RestartTracker::new(); + let b = budget(3, 10_000); + assert!(tracker.record_restart(&b)); + assert!(tracker.record_restart(&b)); + assert!(tracker.record_restart(&b)); + assert_eq!(tracker.count(), 3); + } + + #[test] + fn record_restart_exceeds_budget() { + let mut tracker = RestartTracker::new(); + let b = budget(2, 10_000); + assert!(tracker.record_restart(&b)); + assert!(tracker.record_restart(&b)); + assert!(!tracker.record_restart(&b)); + } + + #[test] + fn record_restart_at_exact_boundary() { + let mut tracker = RestartTracker::new(); + let b = budget(1, 10_000); + assert!(tracker.record_restart(&b)); + assert!(!tracker.record_restart(&b)); + } + + #[test] + fn window_expiry_resets_counter() { + let mut tracker = RestartTracker::new(); + let b = budget(1, 50); + assert!(tracker.record_restart(&b)); + std::thread::sleep(Duration::from_millis(100)); + assert!(tracker.record_restart(&b)); + } + + #[test] + fn maybe_reset_clears_after_quiet_period() { + let mut tracker = RestartTracker::new(); + let mut b = budget(3, 10_000); + b.reset_after = Some(Duration::from_millis(50)); + + tracker.record_restart(&b); + tracker.record_restart(&b); + assert_eq!(tracker.count(), 2); + + std::thread::sleep(Duration::from_millis(100)); + tracker.maybe_reset(&b); + assert_eq!(tracker.count(), 0); + } + + #[test] + fn maybe_reset_noop_before_threshold() { + let mut tracker = RestartTracker::new(); + let mut b = budget(3, 10_000); + b.reset_after = Some(Duration::from_secs(60)); + + tracker.record_restart(&b); + tracker.record_restart(&b); + tracker.maybe_reset(&b); + assert_eq!(tracker.count(), 2); + } + + #[test] + fn maybe_reset_noop_when_none() { + let mut tracker = RestartTracker::new(); + let b = budget(3, 10_000); + + tracker.record_restart(&b); + tracker.maybe_reset(&b); + assert_eq!(tracker.count(), 1); + } + + #[test] + fn zero_budget_always_exceeds() { + let mut tracker = RestartTracker::new(); + let b = budget(0, 10_000); + assert!(!tracker.record_restart(&b)); + } +} diff --git a/crates/supervisor/src/retry.rs b/crates/supervisor/src/retry.rs new file mode 100644 index 0000000000..2f576b49c8 --- /dev/null +++ b/crates/supervisor/src/retry.rs @@ -0,0 +1,187 @@ +use ractor::{ActorCell, SpawnErr}; +use std::future::Future; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub struct RetryStrategy { + pub max_attempts: u32, + pub base_delay: Duration, +} + +impl Default for RetryStrategy { + fn default() -> Self { + Self { + max_attempts: 3, + base_delay: Duration::from_millis(100), + } + } +} + +/// Spawn an actor with exponential-backoff retries. +/// Returns `Some(cell)` on first success, `None` if all attempts fail. 
+pub async fn spawn_with_retry(strategy: &RetryStrategy, spawn_fn: F) -> Option +where + F: Fn() -> Fut, + Fut: Future>, +{ + for attempt in 0..strategy.max_attempts { + let delay = strategy.base_delay * 2u32.pow(attempt); + tokio::time::sleep(delay).await; + + match spawn_fn().await { + Ok(cell) => { + tracing::info!(attempt, "spawn_retry_succeeded"); + return Some(cell); + } + Err(e) => { + tracing::warn!(attempt, error = ?e, "spawn_retry_failed"); + } + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + use ractor::{Actor, ActorProcessingErr, ActorRef}; + use std::sync::Arc; + use std::sync::atomic::{AtomicU64, Ordering}; + use std::time::Instant; + + static TEST_SEQ: AtomicU64 = AtomicU64::new(0); + + fn unique_name(prefix: &str) -> String { + format!( + "{prefix}_{}_{}", + std::process::id(), + TEST_SEQ.fetch_add(1, Ordering::Relaxed) + ) + } + + struct DummyActor; + + #[ractor::async_trait] + impl Actor for DummyActor { + type Msg = (); + type State = (); + type Arguments = (); + + async fn pre_start( + &self, + _myself: ActorRef, + _args: Self::Arguments, + ) -> Result { + Ok(()) + } + } + + async fn spawn_name_collision_err(name: &str) -> SpawnErr { + Actor::spawn(Some(name.to_string()), DummyActor, ()) + .await + .expect_err("expected name collision to produce SpawnErr") + } + + #[tokio::test] + async fn zero_attempts_never_calls_spawn() { + let strategy = RetryStrategy { + max_attempts: 0, + base_delay: Duration::from_millis(5), + }; + let calls = Arc::new(AtomicU64::new(0)); + let calls2 = calls.clone(); + + let result = spawn_with_retry(&strategy, move || { + calls2.fetch_add(1, Ordering::SeqCst); + async { + panic!("spawn closure must not be called when max_attempts is zero"); + #[allow(unreachable_code)] + Err(spawn_name_collision_err("unused").await) + } + }) + .await; + + assert!(result.is_none()); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn retries_until_success() { + let keep_name = unique_name("retry_keep"); + let (keep_ref, keep_handle) = Actor::spawn(Some(keep_name.clone()), DummyActor, ()) + .await + .expect("failed to spawn keeper actor"); + + let success_name = unique_name("retry_success"); + let attempts = Arc::new(AtomicU64::new(0)); + let attempts2 = attempts.clone(); + + let strategy = RetryStrategy { + max_attempts: 4, + base_delay: Duration::from_millis(10), + }; + + let start = Instant::now(); + let result = spawn_with_retry(&strategy, move || { + let keep_name = keep_name.clone(); + let success_name = success_name.clone(); + let attempts = attempts2.clone(); + async move { + let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1; + if attempt < 3 { + Err(spawn_name_collision_err(&keep_name).await) + } else { + let (ok_ref, _ok_handle) = + Actor::spawn(Some(success_name), DummyActor, ()).await?; + Ok(ok_ref.get_cell()) + } + } + }) + .await; + + let elapsed = start.elapsed(); + assert!(result.is_some()); + assert_eq!(attempts.load(Ordering::SeqCst), 3); + assert!( + elapsed >= Duration::from_millis(60), + "expected backoff delays before success, got {elapsed:?}" + ); + + if let Some(cell) = result { + cell.kill(); + } + keep_ref.stop(None); + let _ = keep_handle.await; + } + + #[tokio::test] + async fn returns_none_after_exhausting_attempts() { + let keep_name = unique_name("retry_keep_fail"); + let (keep_ref, keep_handle) = Actor::spawn(Some(keep_name.clone()), DummyActor, ()) + .await + .expect("failed to spawn keeper actor"); + + let attempts = Arc::new(AtomicU64::new(0)); + let attempts2 = 
attempts.clone(); + let strategy = RetryStrategy { + max_attempts: 3, + base_delay: Duration::from_millis(10), + }; + + let result = spawn_with_retry(&strategy, move || { + let keep_name = keep_name.clone(); + let attempts = attempts2.clone(); + async move { + attempts.fetch_add(1, Ordering::SeqCst); + Err(spawn_name_collision_err(&keep_name).await) + } + }) + .await; + + assert!(result.is_none()); + assert_eq!(attempts.load(Ordering::SeqCst), 3); + + keep_ref.stop(None); + let _ = keep_handle.await; + } +} diff --git a/crates/supervisor/src/supervisor.rs b/crates/supervisor/src/supervisor.rs new file mode 100644 index 0000000000..9409562d75 --- /dev/null +++ b/crates/supervisor/src/supervisor.rs @@ -0,0 +1,615 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SpawnErr, SupervisionEvent}; + +use crate::restart::{RestartBudget, RestartTracker}; +use crate::retry::{RetryStrategy, spawn_with_retry}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RestartPolicy { + Permanent, + Transient, + Temporary, +} + +pub type SpawnFuture = Pin> + Send>>; + +#[derive(Clone)] +pub struct SpawnFn(Arc SpawnFuture + Send + Sync>); + +impl SpawnFn { + pub fn new(f: F) -> Self + where + F: Fn(ActorCell) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + Self(Arc::new(move |cell| Box::pin(f(cell)))) + } + + pub async fn call(&self, supervisor_cell: ActorCell) -> Result { + (self.0)(supervisor_cell).await + } +} + +pub struct ChildSpec { + pub id: String, + pub restart_policy: RestartPolicy, + pub spawn_fn: SpawnFn, +} + +pub struct SupervisorConfig { + pub children: Vec, + pub restart_budget: RestartBudget, + pub retry_strategy: RetryStrategy, +} + +struct ChildEntry { + id: String, + cell: Option, + policy: RestartPolicy, + spawn_fn: SpawnFn, + tracker: RestartTracker, +} + +pub struct SupervisorState { + children: Vec, + budget: RestartBudget, + retry_strategy: RetryStrategy, + shutting_down: bool, +} + +pub struct Supervisor; + +#[derive(Debug)] +pub enum SupervisorMsg { + Shutdown, +} + +impl SupervisorState { + fn find_child_index(&self, cell: &ActorCell) -> Option { + self.children + .iter() + .position(|e| e.cell.as_ref().is_some_and(|c| c.get_id() == cell.get_id())) + } + + fn stop_all_children(&mut self) { + for entry in &mut self.children { + if let Some(cell) = entry.cell.take() { + cell.stop(Some("supervisor_shutdown".to_string())); + } + } + } +} + +fn should_restart(policy: RestartPolicy, abnormal: bool) -> bool { + match policy { + RestartPolicy::Permanent => true, + RestartPolicy::Transient => abnormal, + RestartPolicy::Temporary => false, + } +} + +async fn handle_child_exit( + myself: &ActorRef, + state: &mut SupervisorState, + idx: usize, + abnormal: bool, +) { + state.children[idx].cell = None; + + if !should_restart(state.children[idx].policy, abnormal) { + return; + } + + if !state.children[idx].tracker.record_restart(&state.budget) { + tracing::error!(child = %state.children[idx].id, "restart_limit_exceeded"); + state.shutting_down = true; + state.stop_all_children(); + myself.stop(Some("meltdown".to_string())); + return; + } + + let spawn_fn = state.children[idx].spawn_fn.clone(); + let sup_cell = myself.get_cell(); + let retry = state.retry_strategy.clone(); + + let new_cell = spawn_with_retry(&retry, || { + let sup = sup_cell.clone(); + let sf = spawn_fn.clone(); + async move { sf.call(sup).await } + }) + .await; + + match new_cell { + Some(cell) => { + 
state.children[idx].cell = Some(cell); + } + None => { + tracing::error!(child = %state.children[idx].id, "spawn_retry_exhausted"); + state.shutting_down = true; + state.stop_all_children(); + myself.stop(Some("spawn_retry_exhausted".to_string())); + } + } +} + +#[ractor::async_trait] +impl Actor for Supervisor { + type Msg = SupervisorMsg; + type State = SupervisorState; + type Arguments = SupervisorConfig; + + async fn pre_start( + &self, + myself: ActorRef, + config: Self::Arguments, + ) -> Result { + let mut children = Vec::new(); + + for spec in config.children { + let cell = + spec.spawn_fn + .call(myself.get_cell()) + .await + .map_err(|e| -> ActorProcessingErr { + format!("failed to spawn child '{}': {}", spec.id, e).into() + })?; + + children.push(ChildEntry { + id: spec.id, + cell: Some(cell), + policy: spec.restart_policy, + spawn_fn: spec.spawn_fn, + tracker: RestartTracker::new(), + }); + } + + Ok(SupervisorState { + children, + budget: config.restart_budget, + retry_strategy: config.retry_strategy, + shutting_down: false, + }) + } + + async fn handle( + &self, + myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + SupervisorMsg::Shutdown => { + state.shutting_down = true; + state.stop_all_children(); + myself.stop(None); + } + } + Ok(()) + } + + async fn handle_supervisor_evt( + &self, + myself: ActorRef, + message: SupervisionEvent, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + for child in &mut state.children { + child.tracker.maybe_reset(&state.budget); + } + + if state.shutting_down { + return Ok(()); + } + + match message { + SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} + + SupervisionEvent::ActorTerminated(cell, _, _reason) => { + if let Some(idx) = state.find_child_index(&cell) { + handle_child_exit(&myself, state, idx, false).await; + } + } + + SupervisionEvent::ActorFailed(cell, _error) => { + if let Some(idx) = state.find_child_index(&cell) { + handle_child_exit(&myself, state, idx, true).await; + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ractor::{Actor, ActorRef, ActorStatus}; + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + use std::time::Duration; + + // ---- Test child actor with configurable behaviors ---- + + #[derive(Clone)] + enum ChildBehavior { + Healthy, + DelayedFail { ms: u64 }, + DelayedNormal { ms: u64 }, + } + + struct TestChild { + counter: Arc, + } + + #[ractor::async_trait] + impl Actor for TestChild { + type Msg = (); + type State = ChildBehavior; + type Arguments = ChildBehavior; + + async fn pre_start( + &self, + myself: ActorRef, + behavior: Self::Arguments, + ) -> Result { + self.counter.fetch_add(1, Ordering::SeqCst); + + match &behavior { + ChildBehavior::Healthy => {} + ChildBehavior::DelayedFail { ms } => { + myself.send_after(Duration::from_millis(*ms), || ()); + } + ChildBehavior::DelayedNormal { ms } => { + myself.send_after(Duration::from_millis(*ms), || ()); + } + } + Ok(behavior) + } + + async fn handle( + &self, + myself: ActorRef, + _msg: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match state { + ChildBehavior::DelayedFail { .. } => panic!("delayed_fail"), + ChildBehavior::DelayedNormal { .. 
} => myself.stop(None), + _ => {} + } + Ok(()) + } + } + + // ---- Helpers ---- + + fn make_child_spec( + name: &str, + policy: RestartPolicy, + behavior: ChildBehavior, + counter: Arc, + ) -> ChildSpec { + let name = name.to_string(); + ChildSpec { + id: name.clone(), + restart_policy: policy, + spawn_fn: SpawnFn::new(move |sup_cell| { + let behavior = behavior.clone(); + let counter = counter.clone(); + let name = name.clone(); + async move { + let (actor_ref, _) = + Actor::spawn_linked(Some(name), TestChild { counter }, behavior, sup_cell) + .await?; + Ok(actor_ref.get_cell()) + } + }), + } + } + + fn test_budget(max_restarts: u32) -> RestartBudget { + RestartBudget { + max_restarts, + max_window: Duration::from_secs(10), + reset_after: None, + } + } + + fn fast_retry() -> RetryStrategy { + RetryStrategy { + max_attempts: 3, + base_delay: Duration::from_millis(20), + } + } + + // ---- Tests ---- + + #[tokio::test] + async fn permanent_child_restarts_on_failure() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "perm_restart_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(1), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_perm_restart".to_string()), Supervisor, config) + .await + .unwrap(); + + // Child fails once -> restarted -> fails again -> meltdown + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + // initial + 1 restart = 2 spawns + assert_eq!(counter.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn transient_no_restart_on_normal_exit() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "trans_normal_child", + RestartPolicy::Transient, + ChildBehavior::DelayedNormal { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_trans_normal".to_string()), Supervisor, config) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(300)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn transient_restarts_on_failure() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "trans_fail_child", + RestartPolicy::Transient, + ChildBehavior::DelayedFail { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(1), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_trans_fail".to_string()), Supervisor, config) + .await + .unwrap(); + + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + assert_eq!(counter.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn temporary_never_restarts() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "temp_child", + RestartPolicy::Temporary, + ChildBehavior::DelayedFail { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_temp_never".to_string()), Supervisor, config) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(300)).await; + 
assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn meltdown_on_budget_exceeded() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "meltdown_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 50 }, + counter.clone(), + )], + restart_budget: test_budget(2), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_meltdown".to_string()), Supervisor, config) + .await + .unwrap(); + + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + // initial + 2 restarts = 3 spawns, then meltdown on 3rd failure + assert_eq!(counter.load(Ordering::SeqCst), 3); + } + + #[tokio::test] + async fn reset_after_quiet_period() { + let counter = Arc::new(AtomicU32::new(0)); + + // Budget: 1 restart in 10s window, but reset_after 200ms + let budget = RestartBudget { + max_restarts: 1, + max_window: Duration::from_secs(10), + reset_after: Some(Duration::from_millis(200)), + }; + + let config = SupervisorConfig { + children: vec![make_child_spec( + "reset_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 300 }, + counter.clone(), + )], + restart_budget: budget, + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_reset_quiet".to_string()), Supervisor, config) + .await + .unwrap(); + + // Child fails at ~300ms, restarted (count=1). + // Fails again at ~600ms. reset_after=200ms < 300ms gap => counter resets => count=1 again, no meltdown. + // Let it run through a few cycles. + tokio::time::sleep(Duration::from_millis(1500)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!(counter.load(Ordering::SeqCst) >= 3); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn shutdown_suppresses_restarts() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "shutdown_child", + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + )], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = Actor::spawn( + Some("test_shutdown_suppress".to_string()), + Supervisor, + config, + ) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + + let _ = sup_ref.cast(SupervisorMsg::Shutdown); + let _ = sup_handle.await; + + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + // Only spawned once; shutdown should not trigger restarts + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn multiple_children_independent() { + let healthy_counter = Arc::new(AtomicU32::new(0)); + let failing_counter = Arc::new(AtomicU32::new(0)); + + let config = SupervisorConfig { + children: vec![ + make_child_spec( + "multi_healthy", + RestartPolicy::Permanent, + ChildBehavior::Healthy, + healthy_counter.clone(), + ), + make_child_spec( + "multi_temp_fail", + RestartPolicy::Temporary, + ChildBehavior::DelayedFail { ms: 100 }, + failing_counter.clone(), + ), + ], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_multi_indep".to_string()), Supervisor, config) + .await + .unwrap(); + + 
tokio::time::sleep(Duration::from_millis(300)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + + // Healthy child spawned once and still running + assert_eq!(healthy_counter.load(Ordering::SeqCst), 1); + // Temporary child failed once, not restarted + assert_eq!(failing_counter.load(Ordering::SeqCst), 1); + + // Supervisor still has the healthy child + let running: Vec<_> = sup_ref + .get_children() + .into_iter() + .filter(|c| c.get_status() == ActorStatus::Running) + .collect(); + assert_eq!(running.len(), 1); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn window_expiry_allows_more_restarts() { + let counter = Arc::new(AtomicU32::new(0)); + + // Budget: 1 restart in a 200ms window (no reset_after) + let budget = RestartBudget { + max_restarts: 1, + max_window: Duration::from_millis(200), + reset_after: None, + }; + + let config = SupervisorConfig { + children: vec![make_child_spec( + "window_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 300 }, + counter.clone(), + )], + restart_budget: budget, + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_window_expiry".to_string()), Supervisor, config) + .await + .unwrap(); + + // Each child lives ~300ms, then fails. Window is 200ms, so each failure starts a new window. + // With budget=1 per window, each failure is restart #1 in a fresh window => no meltdown. + tokio::time::sleep(Duration::from_millis(1200)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!(counter.load(Ordering::SeqCst) >= 3); + + sup_ref.stop(None); + let _ = sup_handle.await; + } +} diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 12c7566338..6dd3e9f5d9 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -63,8 +63,8 @@ uuid = { workspace = true, features = ["v4"] } hound = { workspace = true } vorbis_rs = { workspace = true } -ractor = { workspace = true } -ractor-supervisor = { workspace = true } +hypr-supervisor = { workspace = true } +ractor = { workspace = true, features = ["async-trait"] } futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/plugins/listener/js/bindings.gen.ts b/plugins/listener/js/bindings.gen.ts index a37b6d3364..a2bae4c8c5 100644 --- a/plugins/listener/js/bindings.gen.ts +++ b/plugins/listener/js/bindings.gen.ts @@ -109,9 +109,10 @@ sessionProgressEvent: "plugin:listener:session-progress-event" /** user-defined types **/ +export type DegradedError = { type: "authentication_failed"; provider: string } | { type: "upstream_unavailable"; message: string } | { type: "connection_timeout" } | { type: "stream_error"; message: string } | { type: "channel_overflow" } export type SessionDataEvent = { type: "audio_amplitude"; session_id: string; mic: number; speaker: number } | { type: "mic_muted"; session_id: string; value: boolean } | { type: "stream_response"; session_id: string; response: StreamResponse } export type SessionErrorEvent = { type: "audio_error"; session_id: string; error: string; device: string | null; is_fatal: boolean } | { type: "connection_error"; session_id: string; error: string } -export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error: string | null } | { type: "active"; session_id: string } | { type: "finalizing"; session_id: string } +export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error: string | null } | { type: 
"active"; session_id: string; error?: DegradedError | null } | { type: "finalizing"; session_id: string } export type SessionParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[] } export type SessionProgressEvent = { type: "audio_initializing"; session_id: string } | { type: "audio_ready"; session_id: string; device: string | null } | { type: "connecting"; session_id: string } | { type: "connected"; session_id: string; adapter: string } export type State = "active" | "inactive" | "finalizing" diff --git a/plugins/listener/src/actors/root.rs b/plugins/listener/src/actors/root.rs index b08b58c961..81819bf1a5 100644 --- a/plugins/listener/src/actors/root.rs +++ b/plugins/listener/src/actors/root.rs @@ -8,9 +8,10 @@ use tracing::Instrument; use crate::SessionLifecycleEvent; use crate::actors::session::lifecycle::{ clear_sentry_session_context, configure_sentry_session_context, emit_session_ended, - stop_actor_by_name_and_wait, }; -use crate::actors::{SessionContext, SessionParams, session_span, spawn_session_supervisor}; +use crate::actors::{ + SessionContext, SessionMsg, SessionParams, session_span, spawn_session_supervisor, +}; pub enum RootMsg { StartSession(SessionParams, RpcReplyPort), @@ -103,7 +104,8 @@ impl Actor for RootActor { tracing::info!(?reason, "session_supervisor_terminated"); state.supervisor = None; state.finalizing = false; - emit_session_ended(&state.app, &session_id, None); + + emit_session_ended(&state.app, &session_id, reason); } } SupervisionEvent::ActorFailed(cell, error) => { @@ -171,6 +173,7 @@ async fn start_session_impl( if let Err(error) = (SessionLifecycleEvent::Active { session_id: params.session_id, + error: None, }) .emit(&state.app) { @@ -212,9 +215,13 @@ async fn stop_session_impl(state: &mut RootState) { } } - // TO make sure post_stop is called. 
- stop_actor_by_name_and_wait(crate::actors::RecorderActor::name(), "session_stop").await; - - supervisor.stop(None); + let session_ref: ActorRef = supervisor.clone().into(); + if let Err(error) = session_ref.cast(SessionMsg::Shutdown) { + tracing::warn!( + ?error, + "failed_to_cast_session_shutdown_falling_back_to_stop" + ); + supervisor.stop(Some("session_stop_cast_failed".to_string())); + } } } diff --git a/plugins/listener/src/actors/session/mod.rs b/plugins/listener/src/actors/session/mod.rs index bd9419d86e..0610875edc 100644 --- a/plugins/listener/src/actors/session/mod.rs +++ b/plugins/listener/src/actors/session/mod.rs @@ -1,169 +1,6 @@ pub(crate) mod lifecycle; +mod supervisor; +mod types; -use std::path::PathBuf; -use std::time::{Instant, SystemTime}; - -use ractor::concurrency::Duration; -use ractor::{Actor, ActorCell, ActorProcessingErr}; -use ractor_supervisor::SupervisorStrategy; -use ractor_supervisor::core::{ChildBackoffFn, ChildSpec, Restart, SpawnFn}; -use ractor_supervisor::supervisor::{Supervisor, SupervisorArguments, SupervisorOptions}; - -use crate::actors::{ - ChannelMode, ListenerActor, ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs, -}; - -pub const SESSION_SUPERVISOR_PREFIX: &str = "session_supervisor_"; - -/// Creates a tracing span with session context that child events will inherit -pub(crate) fn session_span(session_id: &str) -> tracing::Span { - tracing::info_span!("session", session_id = %session_id) -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] -pub struct SessionParams { - pub session_id: String, - pub languages: Vec, - pub onboarding: bool, - pub record_enabled: bool, - pub model: String, - pub base_url: String, - pub api_key: String, - pub keywords: Vec, -} - -#[derive(Clone)] -pub struct SessionContext { - pub app: tauri::AppHandle, - pub params: SessionParams, - pub app_dir: PathBuf, - pub started_at_instant: Instant, - pub started_at_system: SystemTime, -} - -pub fn session_supervisor_name(session_id: &str) -> String { - format!("{}{}", SESSION_SUPERVISOR_PREFIX, session_id) -} - -fn make_supervisor_options() -> SupervisorOptions { - SupervisorOptions { - strategy: SupervisorStrategy::RestForOne, - max_restarts: 3, - max_window: Duration::from_secs(15), - reset_after: Some(Duration::from_secs(30)), - } -} - -fn make_listener_backoff() -> ChildBackoffFn { - ChildBackoffFn::new(|_id, count, _, _| { - if count == 0 { - None - } else { - Some(Duration::from_millis(500)) - } - }) -} - -pub async fn spawn_session_supervisor( - ctx: SessionContext, -) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> { - let supervisor_name = session_supervisor_name(&ctx.params.session_id); - - let mut child_specs = Vec::new(); - - let ctx_source = ctx.clone(); - child_specs.push(ChildSpec { - id: SourceActor::name().to_string(), - restart: Restart::Permanent, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_source.clone(); - async move { - let (actor_ref, _) = Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - mic_device: None, - onboarding: ctx.params.onboarding, - app: ctx.app.clone(), - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) - } - }), - backoff_fn: None, - reset_after: Some(Duration::from_secs(30)), - }); - - let ctx_listener = ctx.clone(); - child_specs.push(ChildSpec { - id: ListenerActor::name().to_string(), - restart: Restart::Permanent, - spawn_fn: SpawnFn::new(move 
|supervisor_cell, _id| { - let ctx = ctx_listener.clone(); - async move { - let mode = ChannelMode::determine(ctx.params.onboarding); - - let (actor_ref, _) = Actor::spawn_linked( - Some(ListenerActor::name()), - ListenerActor, - ListenerArgs { - app: ctx.app.clone(), - languages: ctx.params.languages.clone(), - onboarding: ctx.params.onboarding, - model: ctx.params.model.clone(), - base_url: ctx.params.base_url.clone(), - api_key: ctx.params.api_key.clone(), - keywords: ctx.params.keywords.clone(), - mode, - session_started_at: ctx.started_at_instant, - session_started_at_unix: ctx.started_at_system, - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) - } - }), - backoff_fn: Some(make_listener_backoff()), - reset_after: Some(Duration::from_secs(30)), - }); - - if ctx.params.record_enabled { - let ctx_recorder = ctx.clone(); - child_specs.push(ChildSpec { - id: RecorderActor::name().to_string(), - restart: Restart::Transient, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_recorder.clone(); - async move { - let (actor_ref, _) = Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir: ctx.app_dir.clone(), - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) - } - }), - backoff_fn: None, - reset_after: None, - }); - } - - let args = SupervisorArguments { - child_specs, - options: make_supervisor_options(), - }; - - let (supervisor_ref, handle) = Supervisor::spawn(supervisor_name, args).await?; - - Ok((supervisor_ref.get_cell(), handle)) -} +pub use supervisor::*; +pub use types::*; diff --git a/plugins/listener/src/actors/session/supervisor.rs b/plugins/listener/src/actors/session/supervisor.rs new file mode 100644 index 0000000000..ca55e18d39 --- /dev/null +++ b/plugins/listener/src/actors/session/supervisor.rs @@ -0,0 +1,419 @@ +use hypr_supervisor::{RestartBudget, RestartTracker, RetryStrategy, spawn_with_retry}; +use ractor::concurrency::Duration; +use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SupervisionEvent}; +use tauri_specta::Event; +use tracing::Instrument; + +use crate::DegradedError; +use crate::SessionLifecycleEvent; +use crate::actors::session::lifecycle; +use crate::actors::session::types::{SessionContext, session_span, session_supervisor_name}; +use crate::actors::{ + ChannelMode, ListenerActor, ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ChildKind { + Source, + Listener, + Recorder, +} + +const RESTART_BUDGET: RestartBudget = RestartBudget { + max_restarts: 3, + max_window: Duration::from_secs(15), + reset_after: Some(Duration::from_secs(30)), +}; + +const RETRY_STRATEGY: RetryStrategy = RetryStrategy { + max_attempts: 3, + base_delay: Duration::from_millis(100), +}; + +pub struct SessionState { + ctx: SessionContext, + source_cell: Option, + listener_cell: Option, + recorder_cell: Option, + listener_degraded: Option, + source_restarts: RestartTracker, + recorder_restarts: RestartTracker, + shutting_down: bool, +} + +pub struct SessionActor; + +#[derive(Debug)] +pub enum SessionMsg { + Shutdown, +} + +#[ractor::async_trait] +impl Actor for SessionActor { + type Msg = SessionMsg; + type State = SessionState; + type Arguments = SessionContext; + + async fn pre_start( + &self, + myself: ActorRef, + ctx: Self::Arguments, + ) -> Result { + let session_id = ctx.params.session_id.clone(); + let span = 
+
+        async {
+            let (source_ref, _) = Actor::spawn_linked(
+                Some(SourceActor::name()),
+                SourceActor,
+                SourceArgs {
+                    mic_device: None,
+                    onboarding: ctx.params.onboarding,
+                    app: ctx.app.clone(),
+                    session_id: ctx.params.session_id.clone(),
+                },
+                myself.get_cell(),
+            )
+            .await?;
+
+            let mode = ChannelMode::determine(ctx.params.onboarding);
+            let (listener_ref, _) = Actor::spawn_linked(
+                Some(ListenerActor::name()),
+                ListenerActor,
+                ListenerArgs {
+                    app: ctx.app.clone(),
+                    languages: ctx.params.languages.clone(),
+                    onboarding: ctx.params.onboarding,
+                    model: ctx.params.model.clone(),
+                    base_url: ctx.params.base_url.clone(),
+                    api_key: ctx.params.api_key.clone(),
+                    keywords: ctx.params.keywords.clone(),
+                    mode,
+                    session_started_at: ctx.started_at_instant,
+                    session_started_at_unix: ctx.started_at_system,
+                    session_id: ctx.params.session_id.clone(),
+                },
+                myself.get_cell(),
+            )
+            .await?;
+
+            let recorder_cell = if ctx.params.record_enabled {
+                let (recorder_ref, _) = Actor::spawn_linked(
+                    Some(RecorderActor::name()),
+                    RecorderActor,
+                    RecArgs {
+                        app_dir: ctx.app_dir.clone(),
+                        session_id: ctx.params.session_id.clone(),
+                    },
+                    myself.get_cell(),
+                )
+                .await?;
+                Some(recorder_ref.get_cell())
+            } else {
+                None
+            };
+
+            Ok(SessionState {
+                ctx,
+                source_cell: Some(source_ref.get_cell()),
+                listener_cell: Some(listener_ref.get_cell()),
+                recorder_cell,
+                listener_degraded: None,
+                source_restarts: RestartTracker::new(),
+                recorder_restarts: RestartTracker::new(),
+                shutting_down: false,
+            })
+        }
+        .instrument(span)
+        .await
+    }
+
+    async fn handle(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        message: Self::Msg,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        match message {
+            SessionMsg::Shutdown => {
+                state.shutting_down = true;
+
+                if let Some(cell) = state.recorder_cell.take() {
+                    cell.stop(Some("session_stop".to_string()));
+                    lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await;
+                }
+
+                if let Some(cell) = state.source_cell.take() {
+                    cell.stop(Some("session_stop".to_string()));
+                }
+                if let Some(cell) = state.listener_cell.take() {
+                    cell.stop(Some("session_stop".to_string()));
+                }
+
+                myself.stop(None);
+            }
+        }
+        Ok(())
+    }
+
+    async fn handle_supervisor_evt(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        message: SupervisionEvent,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        let span = session_span(&state.ctx.params.session_id);
+        let _guard = span.enter();
+
+        state.source_restarts.maybe_reset(&RESTART_BUDGET);
+        state.recorder_restarts.maybe_reset(&RESTART_BUDGET);
+
+        if state.shutting_down {
+            return Ok(());
+        }
+
+        match message {
+            SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {}
+
+            SupervisionEvent::ActorTerminated(cell, _, reason) => {
+                match identify_child(state, &cell) {
+                    Some(ChildKind::Listener) => {
+                        tracing::info!(?reason, "listener_terminated_entering_degraded_mode");
+                        let degraded = parse_degraded_reason(reason.as_ref());
+                        state.listener_degraded = Some(degraded.clone());
+                        state.listener_cell = None;
+
+                        let _ = (SessionLifecycleEvent::Active {
+                            session_id: state.ctx.params.session_id.clone(),
+                            error: Some(degraded),
+                        })
+                        .emit(&state.ctx.app);
+                    }
+                    Some(ChildKind::Source) => {
+                        tracing::info!(?reason, "source_terminated_attempting_restart");
+                        state.source_cell = None;
+                        if !try_restart_source(myself.get_cell(), state).await {
+                            tracing::error!("source_restart_limit_exceeded_meltdown");
+                            meltdown(myself, state).await;
+                        }
+                    }
+                    Some(ChildKind::Recorder) => {
+                        tracing::info!(?reason, "recorder_terminated_attempting_restart");
+                        state.recorder_cell = None;
+                        if !try_restart_recorder(myself.get_cell(), state).await {
+                            tracing::error!("recorder_restart_limit_exceeded_meltdown");
+                            meltdown(myself, state).await;
+                        }
+                    }
+                    None => {
+                        tracing::warn!("unknown_child_terminated");
+                    }
+                }
+            }
+
+            SupervisionEvent::ActorFailed(cell, error) => match identify_child(state, &cell) {
+                Some(ChildKind::Listener) => {
+                    tracing::info!(?error, "listener_failed_entering_degraded_mode");
+                    let degraded = DegradedError::StreamError {
+                        message: format!("{:?}", error),
+                    };
+                    state.listener_degraded = Some(degraded.clone());
+                    state.listener_cell = None;
+
+                    let _ = (SessionLifecycleEvent::Active {
+                        session_id: state.ctx.params.session_id.clone(),
+                        error: Some(degraded),
+                    })
+                    .emit(&state.ctx.app);
+                }
+                Some(ChildKind::Source) => {
+                    tracing::warn!(?error, "source_failed_attempting_restart");
+                    state.source_cell = None;
+                    if !try_restart_source(myself.get_cell(), state).await {
+                        tracing::error!("source_restart_limit_exceeded_meltdown");
+                        meltdown(myself, state).await;
+                    }
+                }
+                Some(ChildKind::Recorder) => {
+                    tracing::warn!(?error, "recorder_failed_attempting_restart");
+                    state.recorder_cell = None;
+                    if !try_restart_recorder(myself.get_cell(), state).await {
+                        tracing::error!("recorder_restart_limit_exceeded_meltdown");
+                        meltdown(myself, state).await;
+                    }
+                }
+                None => {
+                    tracing::warn!("unknown_child_failed");
+                }
+            },
+        }
+        Ok(())
+    }
+}
+
+fn identify_child(state: &SessionState, cell: &ActorCell) -> Option<ChildKind> {
+    if state
+        .source_cell
+        .as_ref()
+        .is_some_and(|c| c.get_id() == cell.get_id())
+    {
+        return Some(ChildKind::Source);
+    }
+    if state
+        .listener_cell
+        .as_ref()
+        .is_some_and(|c| c.get_id() == cell.get_id())
+    {
+        return Some(ChildKind::Listener);
+    }
+    if state
+        .recorder_cell
+        .as_ref()
+        .is_some_and(|c| c.get_id() == cell.get_id())
+    {
+        return Some(ChildKind::Recorder);
+    }
+    None
+}
+
+async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState) -> bool {
+    if !state.source_restarts.record_restart(&RESTART_BUDGET) {
+        return false;
+    }
+
+    let sup = supervisor_cell;
+    let onboarding = state.ctx.params.onboarding;
+    let app = state.ctx.app.clone();
+    let session_id = state.ctx.params.session_id.clone();
+
+    let cell = spawn_with_retry(&RETRY_STRATEGY, || {
+        let sup = sup.clone();
+        let app = app.clone();
+        let session_id = session_id.clone();
+        async move {
+            let (r, _) = Actor::spawn_linked(
+                Some(SourceActor::name()),
+                SourceActor,
+                SourceArgs {
+                    mic_device: None,
+                    onboarding,
+                    app,
+                    session_id,
+                },
+                sup,
+            )
+            .await?;
+            Ok(r.get_cell())
+        }
+    })
+    .await;
+
+    match cell {
+        Some(c) => {
+            state.source_cell = Some(c);
+            true
+        }
+        None => false,
+    }
+}
+
+async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionState) -> bool {
+    if !state.ctx.params.record_enabled {
+        return true;
+    }
+
+    if !state.recorder_restarts.record_restart(&RESTART_BUDGET) {
+        return false;
+    }
+
+    let sup = supervisor_cell;
+    let app_dir = state.ctx.app_dir.clone();
+    let session_id = state.ctx.params.session_id.clone();
+
+    let cell = spawn_with_retry(&RETRY_STRATEGY, || {
+        let sup = sup.clone();
+        let app_dir = app_dir.clone();
+        let session_id = session_id.clone();
+        async move {
+            let (r, _) = Actor::spawn_linked(
+                Some(RecorderActor::name()),
+                RecorderActor,
+                RecArgs {
+                    app_dir,
+                    session_id,
+                },
+                sup,
+            )
+            .await?;
+            Ok(r.get_cell())
+        }
+    })
+    .await;
+
+    match cell {
+        Some(c) => {
+            state.recorder_cell = Some(c);
+            true
+        }
+        None => false,
+    }
+}
+
+async fn meltdown(myself: ActorRef<SessionMsg>, state: &mut SessionState) {
+    state.shutting_down = true;
+
+    if let Some(cell) = state.source_cell.take() {
+        cell.stop(Some("meltdown".to_string()));
+    }
+    if let Some(cell) = state.listener_cell.take() {
+        cell.stop(Some("meltdown".to_string()));
+    }
+    if let Some(cell) = state.recorder_cell.take() {
+        cell.stop(Some("meltdown".to_string()));
+        lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await;
+    }
+    myself.stop(Some("restart_limit_exceeded".to_string()));
+}
+
+fn parse_degraded_reason(reason: Option<&String>) -> DegradedError {
+    reason
+        .and_then(|r| serde_json::from_str::<DegradedError>(r).ok())
+        .unwrap_or_else(|| DegradedError::StreamError {
+            message: reason
+                .cloned()
+                .unwrap_or_else(|| "listener terminated without reason".to_string()),
+        })
+}
+
+pub async fn spawn_session_supervisor(
+    ctx: SessionContext,
+) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> {
+    let supervisor_name = session_supervisor_name(&ctx.params.session_id);
+    let (actor_ref, handle) = Actor::spawn(Some(supervisor_name), SessionActor, ctx).await?;
+    Ok((actor_ref.get_cell(), handle))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn parse_degraded_reason_uses_json_payload() {
+        let reason = serde_json::to_string(&DegradedError::ConnectionTimeout).unwrap();
+        let parsed = parse_degraded_reason(Some(&reason));
+        assert!(matches!(parsed, DegradedError::ConnectionTimeout));
+    }
+
+    #[test]
+    fn parse_degraded_reason_falls_back_for_missing_reason() {
+        let parsed = parse_degraded_reason(None);
+        assert!(matches!(parsed, DegradedError::StreamError { .. }));
+    }
+
+    #[test]
+    fn parse_degraded_reason_falls_back_for_invalid_json() {
+        let reason = "not-json".to_string();
+        let parsed = parse_degraded_reason(Some(&reason));
+        assert!(matches!(parsed, DegradedError::StreamError { .. }));
+    }
+}
diff --git a/plugins/listener/src/actors/session/types.rs b/plugins/listener/src/actors/session/types.rs
new file mode 100644
index 0000000000..9bf5df0b12
--- /dev/null
+++ b/plugins/listener/src/actors/session/types.rs
@@ -0,0 +1,33 @@
+use std::path::PathBuf;
+use std::time::{Instant, SystemTime};
+
+pub const SESSION_SUPERVISOR_PREFIX: &str = "session_supervisor_";
+
+pub fn session_span(session_id: &str) -> tracing::Span {
+    tracing::info_span!("session", session_id = %session_id)
+}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)]
+pub struct SessionParams {
+    pub session_id: String,
+    pub languages: Vec,
+    pub onboarding: bool,
+    pub record_enabled: bool,
+    pub model: String,
+    pub base_url: String,
+    pub api_key: String,
+    pub keywords: Vec,
+}
+
+#[derive(Clone)]
+pub struct SessionContext {
+    pub app: tauri::AppHandle,
+    pub params: SessionParams,
+    pub app_dir: PathBuf,
+    pub started_at_instant: Instant,
+    pub started_at_system: SystemTime,
+}
+
+pub fn session_supervisor_name(session_id: &str) -> String {
+    format!("{}{}", SESSION_SUPERVISOR_PREFIX, session_id)
+}
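Note: the new `SessionActor` above replaces the old declarative `RestForOne` child-spec table with explicit per-child supervision built on three `hypr-supervisor` primitives: a `RestartBudget` (how many restarts are tolerated inside a window), a per-child `RestartTracker`, and `spawn_with_retry` (a bounded respawn loop). The sketch below condenses that pattern outside the diff. It is a minimal illustration only, assuming the primitive signatures match the usage shown in `supervisor.rs` above; `WorkerActor` and its unit message/argument types are hypothetical placeholders, not part of this change.

```rust
use hypr_supervisor::{RestartBudget, RestartTracker, RetryStrategy, spawn_with_retry};
use ractor::concurrency::Duration;
use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef};

// Hypothetical child actor, only so the sketch is self-contained.
struct WorkerActor;

#[ractor::async_trait]
impl Actor for WorkerActor {
    type Msg = ();
    type State = ();
    type Arguments = ();

    async fn pre_start(
        &self,
        _myself: ActorRef<Self::Msg>,
        _args: Self::Arguments,
    ) -> Result<Self::State, ActorProcessingErr> {
        Ok(())
    }
}

// Same numbers as RESTART_BUDGET / RETRY_STRATEGY in supervisor.rs above.
const BUDGET: RestartBudget = RestartBudget {
    max_restarts: 3,
    max_window: Duration::from_secs(15),
    reset_after: Some(Duration::from_secs(30)),
};

const RETRY: RetryStrategy = RetryStrategy {
    max_attempts: 3,
    base_delay: Duration::from_millis(100),
};

/// Try to respawn a failed child under `sup`. Returns `false` once the
/// restart budget is exhausted, at which point the caller melts down.
async fn try_restart(sup: ActorCell, tracker: &mut RestartTracker) -> bool {
    // Forget old failures once `reset_after` has elapsed (mirrors `maybe_reset` above).
    tracker.maybe_reset(&BUDGET);

    // Count this restart against the budget; bail out if the window is spent.
    if !tracker.record_restart(&BUDGET) {
        return false;
    }

    // Retry the spawn itself up to `max_attempts` times, linked to the supervisor.
    let cell = spawn_with_retry(&RETRY, || {
        let sup = sup.clone();
        async move {
            let (worker, _) = Actor::spawn_linked(None, WorkerActor, (), sup).await?;
            Ok(worker.get_cell())
        }
    })
    .await;

    cell.is_some()
}
```

Compared with the removed child-spec table, this keeps the policy per child kind (the listener degrades instead of restarting, while source and recorder restart under a budget) at the cost of writing the supervision loop by hand.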
diff --git a/plugins/listener/src/events.rs b/plugins/listener/src/events.rs
index 876a379dc0..a8449f75eb 100644
--- a/plugins/listener/src/events.rs
+++ b/plugins/listener/src/events.rs
@@ -19,7 +19,11 @@ common_event_derives! {
         error: Option<DegradedError>,
     },
     #[serde(rename = "active")]
-    Active { session_id: String },
+    Active {
+        session_id: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        error: Option<DegradedError>,
+    },
     #[serde(rename = "finalizing")]
     Finalizing { session_id: String },
 }
diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml
index 8ffc7888e0..7814ce41a9 100644
--- a/plugins/local-stt/Cargo.toml
+++ b/plugins/local-stt/Cargo.toml
@@ -79,8 +79,8 @@ tokio = { workspace = true, features = ["rt", "macros"] }
 tokio-util = { workspace = true }
 tracing = { workspace = true }
 
-ractor = { workspace = true }
-ractor-supervisor = { workspace = true }
+hypr-supervisor = { workspace = true }
+ractor = { workspace = true, features = ["async-trait"] }
 
 port-killer = "0.1.0"
 port_check = "0.3.0"
diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs
index 497becdbbf..727830d4c2 100644
--- a/plugins/local-stt/src/lib.rs
+++ b/plugins/local-stt/src/lib.rs
@@ -1,5 +1,5 @@
+use hypr_supervisor::dynamic::DynamicSupervisorMsg;
 use ractor::{ActorCell, ActorRef};
-use ractor_supervisor::dynamic::DynamicSupervisorMsg;
 use std::collections::HashMap;
 use tauri::{Manager, Wry};
 use tokio_util::sync::CancellationToken;
diff --git a/plugins/local-stt/src/server/supervisor.rs b/plugins/local-stt/src/server/supervisor.rs
index 93cb6da32f..b89252aec2 100644
--- a/plugins/local-stt/src/server/supervisor.rs
+++ b/plugins/local-stt/src/server/supervisor.rs
@@ -1,8 +1,11 @@
-use ractor::{ActorCell, ActorProcessingErr, ActorRef, concurrency::Duration, registry};
-use ractor_supervisor::{
-    core::{ChildBackoffFn, ChildSpec, Restart, SpawnFn, SupervisorError},
-    dynamic::{DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions},
+use hypr_supervisor::{
+    RestartPolicy,
+    dynamic::{
+        ChildBackoffFn, DynChildSpec, DynSpawnFn, DynamicSupervisor, DynamicSupervisorMsg,
+        DynamicSupervisorOptions, SupervisorError,
+    },
 };
+use ractor::{ActorCell, ActorProcessingErr, ActorRef, concurrency::Duration, registry};
 
 #[cfg(feature = "whisper-cpp")]
 use super::internal::{InternalSTTActor, InternalSTTArgs};
@@ -59,8 +62,8 @@ pub async fn start_external_stt(
 }
 
 #[cfg(feature = "whisper-cpp")]
-fn create_internal_child_spec_with_args(args: InternalSTTArgs) -> ChildSpec {
-    let spawn_fn = SpawnFn::new(move |supervisor: ActorCell, child_id: String| {
+fn create_internal_child_spec_with_args(args: InternalSTTArgs) -> DynChildSpec {
+    let spawn_fn = DynSpawnFn::new(move |supervisor: ActorCell, child_id: String| {
         let args = args.clone();
         async move {
             let (actor_ref, _handle) =
@@ -70,10 +73,10 @@ fn create_internal_child_spec_with_args(args: InternalSTTArgs) -> ChildSpec {
         }
     });
 
-    ChildSpec {
+    DynChildSpec {
         id: INTERNAL_STT_ACTOR_NAME.to_string(),
         spawn_fn,
-        restart: Restart::Transient,
+        restart: RestartPolicy::Transient,
         backoff_fn: Some(ChildBackoffFn::new(|_, _, _, _| {
             Some(Duration::from_millis(500))
         })),
@@ -81,8 +84,8 @@
 }
 
-fn create_external_child_spec_with_args(args: ExternalSTTArgs) -> ChildSpec {
-    let spawn_fn = SpawnFn::new(move |supervisor: ActorCell, child_id: String| {
+fn create_external_child_spec_with_args(args: ExternalSTTArgs) -> DynChildSpec {
+    let spawn_fn = DynSpawnFn::new(move |supervisor: ActorCell, child_id: String| {
         let args = args.clone();
         async move {
             let (actor_ref, _handle) =
@@ -92,10 +95,10 @@ fn create_external_child_spec_with_args(args: ExternalSTTArgs) -> ChildSpec {
         }
     });
 
-    ChildSpec {
+    DynChildSpec {
         id: EXTERNAL_STT_ACTOR_NAME.to_string(),
         spawn_fn,
-        restart: Restart::Transient,
+        restart: RestartPolicy::Transient,
         backoff_fn: Some(ChildBackoffFn::new(|_, _, _, _| {
             Some(Duration::from_secs(1))
         })),
diff --git a/plugins/network/Cargo.toml b/plugins/network/Cargo.toml
index 74201dfe0b..07108903c0 100644
--- a/plugins/network/Cargo.toml
+++ b/plugins/network/Cargo.toml
@@ -21,7 +21,6 @@ serde = { workspace = true }
 specta = { workspace = true }
 
 ractor = { workspace = true }
-ractor-supervisor = { workspace = true }
 reqwest = { workspace = true }
 
 tracing = { workspace = true }
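The local-stt child specs earlier in this change keep constant backoffs (500 ms for the internal engine, 1 s for the external one). The same `ChildBackoffFn` hook can also express a growing delay. The sketch below is a hypothetical variant, assuming the closure arguments follow the usage shown above (`|child_id, restart_count, last_fail, reset_after|`) with an integer restart count, as in the listener backoff that this change removes; it is not part of the diff.

```rust
use hypr_supervisor::dynamic::ChildBackoffFn;
use ractor::concurrency::Duration;

// Hypothetical capped exponential backoff: restart the first failure
// immediately, then wait 200 ms, 400 ms, ... up to ~6.4 s.
fn exponential_backoff() -> ChildBackoffFn {
    ChildBackoffFn::new(|_child_id, restart_count, _last_fail, _reset_after| {
        if restart_count == 0 {
            None
        } else {
            let exp = restart_count.min(6) as u32;
            Some(Duration::from_millis(100 * (1u64 << exp)))
        }
    })
}
```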