From fd98d3056202b4bc33ca9841c6181db27a5ab7cf Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 14 Feb 2026 04:35:16 +0000 Subject: [PATCH 1/5] Replace ractor_supervisor::Supervisor with custom SessionActor for per-actor-type supervision Co-Authored-By: yujonglee --- Cargo.lock | 1 - plugins/listener/Cargo.toml | 1 - plugins/listener/src/actors/root.rs | 1 + plugins/listener/src/actors/session/mod.rs | 450 ++++++++++++++++----- plugins/listener/src/events.rs | 6 +- 5 files changed, 348 insertions(+), 111 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6792dc9224..ca0eef1538 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18910,7 +18910,6 @@ dependencies = [ "quickcheck", "quickcheck_macros", "ractor", - "ractor-supervisor", "rodio", "sentry", "serde", diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 12c7566338..6560151f92 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -64,7 +64,6 @@ hound = { workspace = true } vorbis_rs = { workspace = true } ractor = { workspace = true } -ractor-supervisor = { workspace = true } futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/plugins/listener/src/actors/root.rs b/plugins/listener/src/actors/root.rs index b08b58c961..f8cea6e48d 100644 --- a/plugins/listener/src/actors/root.rs +++ b/plugins/listener/src/actors/root.rs @@ -171,6 +171,7 @@ async fn start_session_impl( if let Err(error) = (SessionLifecycleEvent::Active { session_id: params.session_id, + error: None, }) .emit(&state.app) { diff --git a/plugins/listener/src/actors/session/mod.rs b/plugins/listener/src/actors/session/mod.rs index bd9419d86e..e89b5ce411 100644 --- a/plugins/listener/src/actors/session/mod.rs +++ b/plugins/listener/src/actors/session/mod.rs @@ -4,11 +4,12 @@ use std::path::PathBuf; use std::time::{Instant, SystemTime}; use ractor::concurrency::Duration; -use ractor::{Actor, ActorCell, ActorProcessingErr}; -use ractor_supervisor::SupervisorStrategy; -use ractor_supervisor::core::{ChildBackoffFn, ChildSpec, Restart, SpawnFn}; -use ractor_supervisor::supervisor::{Supervisor, SupervisorArguments, SupervisorOptions}; +use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SupervisionEvent}; +use tauri_specta::Event; +use tracing::Instrument; +use crate::DegradedError; +use crate::SessionLifecycleEvent; use crate::actors::{ ChannelMode, ListenerActor, ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs, }; @@ -45,125 +46,358 @@ pub fn session_supervisor_name(session_id: &str) -> String { format!("{}{}", SESSION_SUPERVISOR_PREFIX, session_id) } -fn make_supervisor_options() -> SupervisorOptions { - SupervisorOptions { - strategy: SupervisorStrategy::RestForOne, - max_restarts: 3, - max_window: Duration::from_secs(15), - reset_after: Some(Duration::from_secs(30)), - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ChildKind { + Source, + Listener, + Recorder, +} + +struct RestartTracker { + count: u32, + window_start: Instant, } -fn make_listener_backoff() -> ChildBackoffFn { - ChildBackoffFn::new(|_id, count, _, _| { - if count == 0 { - None - } else { - Some(Duration::from_millis(500)) +impl RestartTracker { + fn new() -> Self { + Self { + count: 0, + window_start: Instant::now(), + } + } + + fn record_restart(&mut self, max_restarts: u32, max_window: Duration) -> bool { + let now = Instant::now(); + if now.duration_since(self.window_start) > max_window { + self.count = 
0; + self.window_start = now; } - }) + self.count += 1; + self.count <= max_restarts + } + + fn maybe_reset(&mut self, reset_after: Duration) { + let now = Instant::now(); + if now.duration_since(self.window_start) > reset_after { + self.count = 0; + self.window_start = now; + } + } } -pub async fn spawn_session_supervisor( +pub struct SessionState { ctx: SessionContext, -) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> { - let supervisor_name = session_supervisor_name(&ctx.params.session_id); + source_cell: Option, + listener_cell: Option, + recorder_cell: Option, + listener_degraded: Option, + source_restarts: RestartTracker, + recorder_restarts: RestartTracker, +} - let mut child_specs = Vec::new(); - - let ctx_source = ctx.clone(); - child_specs.push(ChildSpec { - id: SourceActor::name().to_string(), - restart: Restart::Permanent, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_source.clone(); - async move { - let (actor_ref, _) = Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - mic_device: None, - onboarding: ctx.params.onboarding, - app: ctx.app.clone(), - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) - } - }), - backoff_fn: None, - reset_after: Some(Duration::from_secs(30)), - }); - - let ctx_listener = ctx.clone(); - child_specs.push(ChildSpec { - id: ListenerActor::name().to_string(), - restart: Restart::Permanent, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_listener.clone(); - async move { - let mode = ChannelMode::determine(ctx.params.onboarding); - - let (actor_ref, _) = Actor::spawn_linked( - Some(ListenerActor::name()), - ListenerActor, - ListenerArgs { - app: ctx.app.clone(), - languages: ctx.params.languages.clone(), - onboarding: ctx.params.onboarding, - model: ctx.params.model.clone(), - base_url: ctx.params.base_url.clone(), - api_key: ctx.params.api_key.clone(), - keywords: ctx.params.keywords.clone(), - mode, - session_started_at: ctx.started_at_instant, - session_started_at_unix: ctx.started_at_system, +pub struct SessionActor; + +impl SessionActor { + const MAX_RESTARTS: u32 = 3; + const MAX_WINDOW: Duration = Duration::from_secs(15); + const RESET_AFTER: Duration = Duration::from_secs(30); +} + +pub enum SessionMsg {} + +#[ractor::async_trait] +impl Actor for SessionActor { + type Msg = SessionMsg; + type State = SessionState; + type Arguments = SessionContext; + + async fn pre_start( + &self, + myself: ActorRef, + ctx: Self::Arguments, + ) -> Result { + let session_id = ctx.params.session_id.clone(); + let span = session_span(&session_id); + + async { + let (source_ref, _) = Actor::spawn_linked( + Some(SourceActor::name()), + SourceActor, + SourceArgs { + mic_device: None, + onboarding: ctx.params.onboarding, + app: ctx.app.clone(), + session_id: ctx.params.session_id.clone(), + }, + myself.get_cell(), + ) + .await?; + + let mode = ChannelMode::determine(ctx.params.onboarding); + let (listener_ref, _) = Actor::spawn_linked( + Some(ListenerActor::name()), + ListenerActor, + ListenerArgs { + app: ctx.app.clone(), + languages: ctx.params.languages.clone(), + onboarding: ctx.params.onboarding, + model: ctx.params.model.clone(), + base_url: ctx.params.base_url.clone(), + api_key: ctx.params.api_key.clone(), + keywords: ctx.params.keywords.clone(), + mode, + session_started_at: ctx.started_at_instant, + session_started_at_unix: ctx.started_at_system, + session_id: ctx.params.session_id.clone(), + }, + 
myself.get_cell(), + ) + .await?; + + let recorder_cell = if ctx.params.record_enabled { + let (recorder_ref, _) = Actor::spawn_linked( + Some(RecorderActor::name()), + RecorderActor, + RecArgs { + app_dir: ctx.app_dir.clone(), session_id: ctx.params.session_id.clone(), }, - supervisor_cell, + myself.get_cell(), ) .await?; - Ok(actor_ref.get_cell()) + Some(recorder_ref.get_cell()) + } else { + None + }; + + Ok(SessionState { + ctx, + source_cell: Some(source_ref.get_cell()), + listener_cell: Some(listener_ref.get_cell()), + recorder_cell, + listener_degraded: None, + source_restarts: RestartTracker::new(), + recorder_restarts: RestartTracker::new(), + }) + } + .instrument(span) + .await + } + + async fn handle( + &self, + _myself: ActorRef, + _message: Self::Msg, + _state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + Ok(()) + } + + async fn handle_supervisor_evt( + &self, + myself: ActorRef, + message: SupervisionEvent, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + let span = session_span(&state.ctx.params.session_id); + let _guard = span.enter(); + + state.source_restarts.maybe_reset(Self::RESET_AFTER); + state.recorder_restarts.maybe_reset(Self::RESET_AFTER); + + match message { + SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} + + SupervisionEvent::ActorTerminated(cell, _, reason) => { + match identify_child(state, &cell) { + Some(ChildKind::Listener) => { + tracing::info!(?reason, "listener_terminated_entering_degraded_mode"); + let degraded = reason + .as_ref() + .and_then(|r| serde_json::from_str::(r).ok()); + state.listener_degraded = degraded.clone(); + state.listener_cell = None; + + let _ = (SessionLifecycleEvent::Active { + session_id: state.ctx.params.session_id.clone(), + error: degraded, + }) + .emit(&state.ctx.app); + } + Some(ChildKind::Source) => { + tracing::info!(?reason, "source_terminated_attempting_restart"); + state.source_cell = None; + if !try_restart_source(myself.get_cell(), state).await { + tracing::error!("source_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + Some(ChildKind::Recorder) => { + tracing::info!(?reason, "recorder_terminated_attempting_restart"); + state.recorder_cell = None; + if !try_restart_recorder(myself.get_cell(), state).await { + tracing::error!("recorder_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + None => { + tracing::warn!("unknown_child_terminated"); + } + } } - }), - backoff_fn: Some(make_listener_backoff()), - reset_after: Some(Duration::from_secs(30)), - }); - - if ctx.params.record_enabled { - let ctx_recorder = ctx.clone(); - child_specs.push(ChildSpec { - id: RecorderActor::name().to_string(), - restart: Restart::Transient, - spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { - let ctx = ctx_recorder.clone(); - async move { - let (actor_ref, _) = Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir: ctx.app_dir.clone(), - session_id: ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await?; - Ok(actor_ref.get_cell()) + + SupervisionEvent::ActorFailed(cell, error) => match identify_child(state, &cell) { + Some(ChildKind::Listener) => { + tracing::info!(?error, "listener_failed_entering_degraded_mode"); + let degraded = DegradedError::StreamError { + message: format!("{:?}", error), + }; + state.listener_degraded = Some(degraded.clone()); + state.listener_cell = None; + + let _ = (SessionLifecycleEvent::Active { + session_id: 
state.ctx.params.session_id.clone(), + error: Some(degraded), + }) + .emit(&state.ctx.app); + } + Some(ChildKind::Source) => { + tracing::warn!(?error, "source_failed_attempting_restart"); + state.source_cell = None; + if !try_restart_source(myself.get_cell(), state).await { + tracing::error!("source_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + Some(ChildKind::Recorder) => { + tracing::warn!(?error, "recorder_failed_attempting_restart"); + state.recorder_cell = None; + if !try_restart_recorder(myself.get_cell(), state).await { + tracing::error!("recorder_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + None => { + tracing::warn!("unknown_child_failed"); } - }), - backoff_fn: None, - reset_after: None, - }); + }, + } + Ok(()) + } +} + +fn identify_child(state: &SessionState, cell: &ActorCell) -> Option { + if state + .source_cell + .as_ref() + .is_some_and(|c| c.get_id() == cell.get_id()) + { + return Some(ChildKind::Source); + } + if state + .listener_cell + .as_ref() + .is_some_and(|c| c.get_id() == cell.get_id()) + { + return Some(ChildKind::Listener); } + if state + .recorder_cell + .as_ref() + .is_some_and(|c| c.get_id() == cell.get_id()) + { + return Some(ChildKind::Recorder); + } + None +} + +async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { + if !state + .source_restarts + .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) + { + return false; + } + + match Actor::spawn_linked( + Some(SourceActor::name()), + SourceActor, + SourceArgs { + mic_device: None, + onboarding: state.ctx.params.onboarding, + app: state.ctx.app.clone(), + session_id: state.ctx.params.session_id.clone(), + }, + supervisor_cell, + ) + .await + { + Ok((actor_ref, _)) => { + state.source_cell = Some(actor_ref.get_cell()); + tracing::info!("source_restarted"); + true + } + Err(e) => { + tracing::error!(error = ?e, "source_restart_failed"); + false + } + } +} - let args = SupervisorArguments { - child_specs, - options: make_supervisor_options(), - }; +async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { + if !state.ctx.params.record_enabled { + return true; + } + + if !state + .recorder_restarts + .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) + { + return false; + } + + match Actor::spawn_linked( + Some(RecorderActor::name()), + RecorderActor, + RecArgs { + app_dir: state.ctx.app_dir.clone(), + session_id: state.ctx.params.session_id.clone(), + }, + supervisor_cell, + ) + .await + { + Ok((actor_ref, _)) => { + state.recorder_cell = Some(actor_ref.get_cell()); + tracing::info!("recorder_restarted"); + true + } + Err(e) => { + tracing::error!(error = ?e, "recorder_restart_failed"); + false + } + } +} + +async fn meltdown(myself: ActorRef, state: &mut SessionState) { + if let Some(cell) = state.source_cell.take() { + cell.stop(Some("meltdown".to_string())); + } + if let Some(cell) = state.listener_cell.take() { + cell.stop(Some("meltdown".to_string())); + } + if let Some(cell) = state.recorder_cell.take() { + cell.stop(Some("meltdown".to_string())); + } + myself.stop(Some("restart_limit_exceeded".to_string())); +} + +pub async fn spawn_session_supervisor( + ctx: SessionContext, +) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> { + let supervisor_name = session_supervisor_name(&ctx.params.session_id); - let (supervisor_ref, handle) = Supervisor::spawn(supervisor_name, args).await?; + let (actor_ref, handle) = 
Actor::spawn(Some(supervisor_name), SessionActor, ctx).await?; - Ok((supervisor_ref.get_cell(), handle)) + Ok((actor_ref.get_cell(), handle)) } diff --git a/plugins/listener/src/events.rs b/plugins/listener/src/events.rs index 876a379dc0..a8449f75eb 100644 --- a/plugins/listener/src/events.rs +++ b/plugins/listener/src/events.rs @@ -19,7 +19,11 @@ common_event_derives! { error: Option, }, #[serde(rename = "active")] - Active { session_id: String }, + Active { + session_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + }, #[serde(rename = "finalizing")] Finalizing { session_id: String }, } From 9792f80be0ae0f9d8e03400df6981e64948b1cc2 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 14 Feb 2026 04:45:51 +0000 Subject: [PATCH 2/5] Fix review: add shutting_down flag to prevent restart during shutdown, move restart counter after successful spawn Co-Authored-By: yujonglee --- plugins/listener/src/actors/root.rs | 11 ++-- plugins/listener/src/actors/session/mod.rs | 61 +++++++++++++++------- 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/plugins/listener/src/actors/root.rs b/plugins/listener/src/actors/root.rs index f8cea6e48d..34a669d99e 100644 --- a/plugins/listener/src/actors/root.rs +++ b/plugins/listener/src/actors/root.rs @@ -8,9 +8,10 @@ use tracing::Instrument; use crate::SessionLifecycleEvent; use crate::actors::session::lifecycle::{ clear_sentry_session_context, configure_sentry_session_context, emit_session_ended, - stop_actor_by_name_and_wait, }; -use crate::actors::{SessionContext, SessionParams, session_span, spawn_session_supervisor}; +use crate::actors::{ + SessionContext, SessionMsg, SessionParams, session_span, spawn_session_supervisor, +}; pub enum RootMsg { StartSession(SessionParams, RpcReplyPort), @@ -213,9 +214,7 @@ async fn stop_session_impl(state: &mut RootState) { } } - // TO make sure post_stop is called. 
- stop_actor_by_name_and_wait(crate::actors::RecorderActor::name(), "session_stop").await; - - supervisor.stop(None); + let session_ref: ActorRef = supervisor.clone().into(); + let _ = session_ref.cast(SessionMsg::Shutdown); } } diff --git a/plugins/listener/src/actors/session/mod.rs b/plugins/listener/src/actors/session/mod.rs index e89b5ce411..1dd5517617 100644 --- a/plugins/listener/src/actors/session/mod.rs +++ b/plugins/listener/src/actors/session/mod.rs @@ -93,6 +93,7 @@ pub struct SessionState { listener_degraded: Option, source_restarts: RestartTracker, recorder_restarts: RestartTracker, + shutting_down: bool, } pub struct SessionActor; @@ -103,7 +104,9 @@ impl SessionActor { const RESET_AFTER: Duration = Duration::from_secs(30); } -pub enum SessionMsg {} +pub enum SessionMsg { + Shutdown, +} #[ractor::async_trait] impl Actor for SessionActor { @@ -178,6 +181,7 @@ impl Actor for SessionActor { listener_degraded: None, source_restarts: RestartTracker::new(), recorder_restarts: RestartTracker::new(), + shutting_down: false, }) } .instrument(span) @@ -186,10 +190,29 @@ impl Actor for SessionActor { async fn handle( &self, - _myself: ActorRef, - _message: Self::Msg, - _state: &mut Self::State, + myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { + match message { + SessionMsg::Shutdown => { + state.shutting_down = true; + + if let Some(cell) = state.recorder_cell.take() { + cell.stop(Some("session_stop".to_string())); + lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; + } + + if let Some(cell) = state.source_cell.take() { + cell.stop(Some("session_stop".to_string())); + } + if let Some(cell) = state.listener_cell.take() { + cell.stop(Some("session_stop".to_string())); + } + + myself.stop(None); + } + } Ok(()) } @@ -205,6 +228,10 @@ impl Actor for SessionActor { state.source_restarts.maybe_reset(Self::RESET_AFTER); state.recorder_restarts.maybe_reset(Self::RESET_AFTER); + if state.shutting_down { + return Ok(()); + } + match message { SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} @@ -312,13 +339,6 @@ fn identify_child(state: &SessionState, cell: &ActorCell) -> Option { } async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { - if !state - .source_restarts - .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) - { - return false; - } - match Actor::spawn_linked( Some(SourceActor::name()), SourceActor, @@ -333,6 +353,12 @@ async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState .await { Ok((actor_ref, _)) => { + if !state + .source_restarts + .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) + { + return false; + } state.source_cell = Some(actor_ref.get_cell()); tracing::info!("source_restarted"); true @@ -349,13 +375,6 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta return true; } - if !state - .recorder_restarts - .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) - { - return false; - } - match Actor::spawn_linked( Some(RecorderActor::name()), RecorderActor, @@ -368,6 +387,12 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta .await { Ok((actor_ref, _)) => { + if !state + .recorder_restarts + .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) + { + return false; + } state.recorder_cell = Some(actor_ref.get_cell()); tracing::info!("recorder_restarted"); true From 
0f7bf1165b05ead45cfa9c0eac06799db2d43225 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 14 Feb 2026 05:01:55 +0000 Subject: [PATCH 3/5] Fix review issues: check restart limit before spawn, add backoff delay, propagate DegradedError, add Debug to SessionMsg Co-Authored-By: yujonglee --- plugins/listener/src/actors/root.rs | 9 ++++++- plugins/listener/src/actors/session/mod.rs | 31 +++++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/plugins/listener/src/actors/root.rs b/plugins/listener/src/actors/root.rs index 34a669d99e..0f6002d7d8 100644 --- a/plugins/listener/src/actors/root.rs +++ b/plugins/listener/src/actors/root.rs @@ -5,6 +5,7 @@ use tauri_plugin_settings::SettingsPluginExt; use tauri_specta::Event; use tracing::Instrument; +use crate::DegradedError; use crate::SessionLifecycleEvent; use crate::actors::session::lifecycle::{ clear_sentry_session_context, configure_sentry_session_context, emit_session_ended, @@ -104,7 +105,13 @@ impl Actor for RootActor { tracing::info!(?reason, "session_supervisor_terminated"); state.supervisor = None; state.finalizing = false; - emit_session_ended(&state.app, &session_id, None); + + let failure_reason = reason.and_then(|r| { + serde_json::from_str::(&r) + .ok() + .map(|d| format!("{:?}", d)) + }); + emit_session_ended(&state.app, &session_id, failure_reason); } } SupervisionEvent::ActorFailed(cell, error) => { diff --git a/plugins/listener/src/actors/session/mod.rs b/plugins/listener/src/actors/session/mod.rs index 1dd5517617..298b2cfa58 100644 --- a/plugins/listener/src/actors/session/mod.rs +++ b/plugins/listener/src/actors/session/mod.rs @@ -104,6 +104,7 @@ impl SessionActor { const RESET_AFTER: Duration = Duration::from_secs(30); } +#[derive(Debug)] pub enum SessionMsg { Shutdown, } @@ -339,6 +340,15 @@ fn identify_child(state: &SessionState, cell: &ActorCell) -> Option { } async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { + if !state + .source_restarts + .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) + { + return false; + } + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + match Actor::spawn_linked( Some(SourceActor::name()), SourceActor, @@ -353,12 +363,6 @@ async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState .await { Ok((actor_ref, _)) => { - if !state - .source_restarts - .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) - { - return false; - } state.source_cell = Some(actor_ref.get_cell()); tracing::info!("source_restarted"); true @@ -375,6 +379,15 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta return true; } + if !state + .recorder_restarts + .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) + { + return false; + } + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + match Actor::spawn_linked( Some(RecorderActor::name()), RecorderActor, @@ -387,12 +400,6 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta .await { Ok((actor_ref, _)) => { - if !state - .recorder_restarts - .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) - { - return false; - } state.recorder_cell = Some(actor_ref.get_cell()); tracing::info!("recorder_restarted"); true From 4b875bb6d499eda0bb0fb39dd20c2d775c0e9901 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> 
Date: Sat, 14 Feb 2026 06:41:48 +0000 Subject: [PATCH 4/5] Address review: retry spawn with exponential backoff, wait for recorder in meltdown, fallback stop reason Co-Authored-By: yujonglee --- plugins/listener/src/actors/root.rs | 1 + plugins/listener/src/actors/session/mod.rs | 99 ++++++++++++---------- 2 files changed, 56 insertions(+), 44 deletions(-) diff --git a/plugins/listener/src/actors/root.rs b/plugins/listener/src/actors/root.rs index 0f6002d7d8..17abd6e0c4 100644 --- a/plugins/listener/src/actors/root.rs +++ b/plugins/listener/src/actors/root.rs @@ -110,6 +110,7 @@ impl Actor for RootActor { serde_json::from_str::(&r) .ok() .map(|d| format!("{:?}", d)) + .or(Some(r)) }); emit_session_ended(&state.app, &session_id, failure_reason); } diff --git a/plugins/listener/src/actors/session/mod.rs b/plugins/listener/src/actors/session/mod.rs index 298b2cfa58..1889f43b59 100644 --- a/plugins/listener/src/actors/session/mod.rs +++ b/plugins/listener/src/actors/session/mod.rs @@ -347,31 +347,36 @@ async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState return false; } - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - - match Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - mic_device: None, - onboarding: state.ctx.params.onboarding, - app: state.ctx.app.clone(), - session_id: state.ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await - { - Ok((actor_ref, _)) => { - state.source_cell = Some(actor_ref.get_cell()); - tracing::info!("source_restarted"); - true - } - Err(e) => { - tracing::error!(error = ?e, "source_restart_failed"); - false + for attempt in 0..3u32 { + let delay = std::time::Duration::from_millis(100 * 2u64.pow(attempt)); + tokio::time::sleep(delay).await; + + match Actor::spawn_linked( + Some(SourceActor::name()), + SourceActor, + SourceArgs { + mic_device: None, + onboarding: state.ctx.params.onboarding, + app: state.ctx.app.clone(), + session_id: state.ctx.params.session_id.clone(), + }, + supervisor_cell.clone(), + ) + .await + { + Ok((actor_ref, _)) => { + state.source_cell = Some(actor_ref.get_cell()); + tracing::info!(attempt, "source_restarted"); + return true; + } + Err(e) => { + tracing::warn!(attempt, error = ?e, "source_spawn_attempt_failed"); + } } } + + tracing::error!("source_restart_failed_all_attempts"); + false } async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { @@ -386,29 +391,34 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta return false; } - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - - match Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir: state.ctx.app_dir.clone(), - session_id: state.ctx.params.session_id.clone(), - }, - supervisor_cell, - ) - .await - { - Ok((actor_ref, _)) => { - state.recorder_cell = Some(actor_ref.get_cell()); - tracing::info!("recorder_restarted"); - true - } - Err(e) => { - tracing::error!(error = ?e, "recorder_restart_failed"); - false + for attempt in 0..3u32 { + let delay = std::time::Duration::from_millis(100 * 2u64.pow(attempt)); + tokio::time::sleep(delay).await; + + match Actor::spawn_linked( + Some(RecorderActor::name()), + RecorderActor, + RecArgs { + app_dir: state.ctx.app_dir.clone(), + session_id: state.ctx.params.session_id.clone(), + }, + supervisor_cell.clone(), + ) + .await + { + Ok((actor_ref, _)) => { + state.recorder_cell = Some(actor_ref.get_cell()); + tracing::info!(attempt, 
"recorder_restarted"); + return true; + } + Err(e) => { + tracing::warn!(attempt, error = ?e, "recorder_spawn_attempt_failed"); + } } } + + tracing::error!("recorder_restart_failed_all_attempts"); + false } async fn meltdown(myself: ActorRef, state: &mut SessionState) { @@ -420,6 +430,7 @@ async fn meltdown(myself: ActorRef, state: &mut SessionState) { } if let Some(cell) = state.recorder_cell.take() { cell.stop(Some("meltdown".to_string())); + lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; } myself.stop(Some("restart_limit_exceeded".to_string())); } From 93a3b7330f7b6946006c22d93f4beda019a717f1 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Sat, 14 Feb 2026 18:26:41 +0900 Subject: [PATCH 5/5] extract custom supervisor --- Cargo.lock | 63 +- Cargo.toml | 4 +- apps/desktop/src-tauri/Cargo.toml | 2 +- apps/desktop/src-tauri/src/supervisor.rs | 4 +- crates/supervisor/Cargo.toml | 10 + crates/supervisor/src/dynamic.rs | 806 ++++++++++++++++++ crates/supervisor/src/lib.rs | 10 + crates/supervisor/src/restart.rs | 164 ++++ crates/supervisor/src/retry.rs | 187 ++++ crates/supervisor/src/supervisor.rs | 615 +++++++++++++ plugins/listener/Cargo.toml | 3 +- plugins/listener/js/bindings.gen.ts | 3 +- plugins/listener/src/actors/root.rs | 17 +- plugins/listener/src/actors/session/mod.rs | 448 +--------- .../listener/src/actors/session/supervisor.rs | 419 +++++++++ plugins/listener/src/actors/session/types.rs | 33 + plugins/local-stt/Cargo.toml | 4 +- plugins/local-stt/src/lib.rs | 2 +- plugins/local-stt/src/server/supervisor.rs | 27 +- plugins/network/Cargo.toml | 1 - 20 files changed, 2326 insertions(+), 496 deletions(-) create mode 100644 crates/supervisor/Cargo.toml create mode 100644 crates/supervisor/src/dynamic.rs create mode 100644 crates/supervisor/src/lib.rs create mode 100644 crates/supervisor/src/restart.rs create mode 100644 crates/supervisor/src/retry.rs create mode 100644 crates/supervisor/src/supervisor.rs create mode 100644 plugins/listener/src/actors/session/supervisor.rs create mode 100644 plugins/listener/src/actors/session/types.rs diff --git a/Cargo.lock b/Cargo.lock index ca0eef1538..2ec3a7e58a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4691,13 +4691,13 @@ dependencies = [ "intercept", "pico-args", "ractor", - "ractor-supervisor", "sentry", "serde", "serde_json", "specta", "specta-typescript", "strum 0.27.2", + "supervisor", "tauri", "tauri-build", "tauri-plugin-analytics", @@ -14393,32 +14393,23 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ractor" -version = "0.14.7" +version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d65972a0286ef14c43c6daafbac6cf15e96496446147683b2905292c35cc178" +checksum = "6102314f700f3e8df466c49110830b18cbfc172f88f27a9d7383e455663b1be7" dependencies = [ "async-trait", "bon 2.3.0", "dashmap", "futures", + "js-sys", "once_cell", "strum 0.26.3", "tokio", + "tokio_with_wasm", "tracing", -] - -[[package]] -name = "ractor-supervisor" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d90830688ebfafdc226f3c9567c40fecf4c51a7513171181102ae66e4b57c15f" -dependencies = [ - "futures-util", - "if_chain", - "log", - "ractor", - "thiserror 2.0.18", - "uuid", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-time", ] [[package]] @@ -17506,6 +17497,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "supervisor" +version = "0.1.0" +dependencies = [ + "ractor", + "thiserror 2.0.18", + "tokio", + 
"tracing", +] + [[package]] name = "svgtypes" version = "0.15.3" @@ -18917,6 +18918,7 @@ dependencies = [ "specta", "specta-typescript", "strum 0.27.2", + "supervisor", "tauri", "tauri-plugin", "tauri-plugin-fs-sync", @@ -19022,7 +19024,6 @@ dependencies = [ "port-killer", "port_check", "ractor", - "ractor-supervisor", "reqwest 0.13.2", "rodio", "serde", @@ -19031,6 +19032,7 @@ dependencies = [ "specta", "specta-typescript", "strum 0.27.2", + "supervisor", "tauri", "tauri-plugin", "tauri-plugin-settings", @@ -19089,7 +19091,6 @@ name = "tauri-plugin-network" version = "0.1.0" dependencies = [ "ractor", - "ractor-supervisor", "reqwest 0.13.2", "serde", "specta", @@ -20394,6 +20395,30 @@ dependencies = [ "webpki-roots 0.26.11", ] +[[package]] +name = "tokio_with_wasm" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34e40fbbbd95441133fe9483f522db15dbfd26dc636164ebd8f2dd28759a6aa6" +dependencies = [ + "js-sys", + "tokio", + "tokio_with_wasm_proc", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "tokio_with_wasm_proc" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d01145a2c788d6aae4cd653afec1e8332534d7d783d01897cefcafe4428de992" +dependencies = [ + "quote", + "syn 2.0.114", +] + [[package]] name = "toml" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index 086fc7c49d..f9091e569f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -198,8 +198,8 @@ async-stream = "0.3.6" futures-channel = "0.3.31" futures-core = "0.3.31" futures-util = "0.3.31" -ractor = "0.14" -ractor-supervisor = "0.1.9" +hypr-supervisor = { path = "crates/supervisor", package = "supervisor" } +ractor = "0.15.10" rayon = "1.11" reqwest = "0.13" reqwest-middleware = "0.5" diff --git a/apps/desktop/src-tauri/Cargo.toml b/apps/desktop/src-tauri/Cargo.toml index fe70712dc5..16c940f4e0 100644 --- a/apps/desktop/src-tauri/Cargo.toml +++ b/apps/desktop/src-tauri/Cargo.toml @@ -94,8 +94,8 @@ tracing = { workspace = true } hypr-host = { workspace = true } +hypr-supervisor = { workspace = true } ractor = { workspace = true } -ractor-supervisor = { workspace = true } pico-args = "0.5" tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/apps/desktop/src-tauri/src/supervisor.rs b/apps/desktop/src-tauri/src/supervisor.rs index b12d913795..cd8cb1c8df 100644 --- a/apps/desktop/src-tauri/src/supervisor.rs +++ b/apps/desktop/src-tauri/src/supervisor.rs @@ -1,11 +1,9 @@ use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; +use hypr_supervisor::dynamic::{DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions}; use ractor::ActorRef; use ractor::concurrency::Duration; -use ractor_supervisor::dynamic::{ - DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions, -}; pub type SupervisorRef = ActorRef; pub type SupervisorHandle = tokio::task::JoinHandle<()>; diff --git a/crates/supervisor/Cargo.toml b/crates/supervisor/Cargo.toml new file mode 100644 index 0000000000..abe4447dad --- /dev/null +++ b/crates/supervisor/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "supervisor" +version = "0.1.0" +edition = "2024" + +[dependencies] +ractor = { workspace = true, features = ["async-trait"] } +thiserror = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } diff --git a/crates/supervisor/src/dynamic.rs b/crates/supervisor/src/dynamic.rs new file mode 100644 index 0000000000..f6f864ba2d --- /dev/null +++ 
b/crates/supervisor/src/dynamic.rs
@@ -0,0 +1,806 @@
+use std::collections::HashMap;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use ractor::{
+    Actor, ActorCell, ActorId, ActorProcessingErr, ActorRef, RpcReplyPort, SpawnErr,
+    SupervisionEvent, concurrency::JoinHandle,
+};
+
+// ---------------------------------------------------------------------------
+// Public types
+// ---------------------------------------------------------------------------
+
+#[derive(Debug, Clone, thiserror::Error)]
+pub enum SupervisorError {
+    #[error("Child '{child_id}' not found in specs")]
+    ChildNotFound { child_id: String },
+
+    #[error("Child '{pid}' does not have a name set")]
+    ChildNameNotSet { pid: ActorId },
+
+    #[error("Max children exceeded")]
+    MaxChildrenExceeded,
+
+    #[error("Meltdown: {reason}")]
+    Meltdown { reason: String },
+}
+
+pub type DynSpawnFuture = Pin<Box<dyn Future<Output = Result<ActorCell, SpawnErr>> + Send>>;
+
+#[derive(Clone)]
+pub struct DynSpawnFn(Arc<dyn Fn(ActorCell, String) -> DynSpawnFuture + Send + Sync>);
+
+impl DynSpawnFn {
+    pub fn new<F, Fut>(f: F) -> Self
+    where
+        F: Fn(ActorCell, String) -> Fut + Send + Sync + 'static,
+        Fut: Future<Output = Result<ActorCell, SpawnErr>> + Send + 'static,
+    {
+        Self(Arc::new(move |cell, id| Box::pin(f(cell, id))))
+    }
+
+    pub async fn call(&self, sup: ActorCell, id: String) -> Result<ActorCell, SpawnErr> {
+        (self.0)(sup, id).await
+    }
+}
+
+type BackoffFn = dyn Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync;
+
+#[derive(Clone)]
+pub struct ChildBackoffFn(Arc<BackoffFn>);
+
+impl ChildBackoffFn {
+    pub fn new<F>(f: F) -> Self
+    where
+        F: Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync + 'static,
+    {
+        Self(Arc::new(f))
+    }
+
+    fn call(
+        &self,
+        child_id: &str,
+        restart_count: usize,
+        last_fail: Instant,
+        reset_after: Option<Duration>,
+    ) -> Option<Duration> {
+        (self.0)(child_id, restart_count, last_fail, reset_after)
+    }
+}
+
+#[derive(Clone)]
+pub struct DynChildSpec {
+    pub id: String,
+    pub restart: crate::RestartPolicy,
+    pub spawn_fn: DynSpawnFn,
+    pub backoff_fn: Option<ChildBackoffFn>,
+    pub reset_after: Option<Duration>,
+}
+
+#[derive(Debug, Clone)]
+pub struct DynamicSupervisorOptions {
+    pub max_children: Option<usize>,
+    pub max_restarts: usize,
+    pub max_window: Duration,
+    pub reset_after: Option<Duration>,
+}
+
+// ---------------------------------------------------------------------------
+// Messages
+// ---------------------------------------------------------------------------
+
+pub enum DynamicSupervisorMsg {
+    SpawnChild {
+        spec: DynChildSpec,
+        reply: Option<RpcReplyPort<Result<(), ActorProcessingErr>>>,
+    },
+    TerminateChild {
+        child_id: String,
+        reply: Option<RpcReplyPort<()>>,
+    },
+    ScheduledRestart {
+        spec: DynChildSpec,
+    },
+}
+
+impl std::fmt::Debug for DynamicSupervisorMsg {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::SpawnChild { spec, .. } => {
+                f.debug_struct("SpawnChild").field("id", &spec.id).finish()
+            }
+            Self::TerminateChild { child_id, .. 
} => f
+                .debug_struct("TerminateChild")
+                .field("child_id", child_id)
+                .finish(),
+            Self::ScheduledRestart { spec } => f
+                .debug_struct("ScheduledRestart")
+                .field("id", &spec.id)
+                .finish(),
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Internal state
+// ---------------------------------------------------------------------------
+
+struct ActiveChild {
+    spec: DynChildSpec,
+    cell: ActorCell,
+}
+
+struct ChildFailureState {
+    restart_count: usize,
+    last_fail: Instant,
+}
+
+struct RestartLogEntry {
+    _child_id: String,
+    timestamp: Instant,
+}
+
+pub struct DynamicSupervisorState {
+    options: DynamicSupervisorOptions,
+    active_children: HashMap<String, ActiveChild>,
+    child_failure_state: HashMap<String, ChildFailureState>,
+    restart_log: Vec<RestartLogEntry>,
+}
+
+// ---------------------------------------------------------------------------
+// Actor
+// ---------------------------------------------------------------------------
+
+pub struct DynamicSupervisor;
+
+impl DynamicSupervisor {
+    pub async fn spawn(
+        name: String,
+        options: DynamicSupervisorOptions,
+    ) -> Result<(ActorRef<DynamicSupervisorMsg>, JoinHandle<()>), SpawnErr> {
+        Actor::spawn(Some(name), DynamicSupervisor, options).await
+    }
+
+    pub async fn spawn_linked<T: Actor>(
+        name: impl Into<String>,
+        handler: T,
+        args: T::Arguments,
+        supervisor: ActorCell,
+    ) -> Result<(ActorRef<T::Msg>, JoinHandle<()>), SpawnErr> {
+        Actor::spawn_linked(Some(name.into()), handler, args, supervisor).await
+    }
+
+    pub async fn spawn_child(
+        sup_ref: ActorRef<DynamicSupervisorMsg>,
+        spec: DynChildSpec,
+    ) -> Result<(), ActorProcessingErr> {
+        ractor::call!(sup_ref, |reply| {
+            DynamicSupervisorMsg::SpawnChild {
+                spec,
+                reply: Some(reply),
+            }
+        })?
+    }
+
+    pub async fn terminate_child(
+        sup_ref: ActorRef<DynamicSupervisorMsg>,
+        child_id: String,
+    ) -> Result<(), ActorProcessingErr> {
+        ractor::call!(sup_ref, |reply| {
+            DynamicSupervisorMsg::TerminateChild {
+                child_id,
+                reply: Some(reply),
+            }
+        })?;
+        Ok(())
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Meltdown tracking
+// ---------------------------------------------------------------------------
+
+impl DynamicSupervisorState {
+    fn track_global_restart(&mut self, child_id: &str) -> Result<(), ActorProcessingErr> {
+        let now = Instant::now();
+
+        if let Some(reset_after) = self.options.reset_after {
+            if let Some(latest) = self.restart_log.last() {
+                if now.duration_since(latest.timestamp) >= reset_after {
+                    self.restart_log.clear();
+                }
+            }
+        }
+
+        self.restart_log.push(RestartLogEntry {
+            _child_id: child_id.to_string(),
+            timestamp: now,
+        });
+
+        self.restart_log
+            .retain(|e| now.duration_since(e.timestamp) < self.options.max_window);
+
+        if self.restart_log.len() > self.options.max_restarts {
+            Err(SupervisorError::Meltdown {
+                reason: "max_restarts exceeded".to_string(),
+            }
+            .into())
+        } else {
+            Ok(())
+        }
+    }
+
+    fn prepare_child_failure(&mut self, spec: &DynChildSpec) {
+        let now = Instant::now();
+        let entry = self
+            .child_failure_state
+            .entry(spec.id.clone())
+            .or_insert(ChildFailureState {
+                restart_count: 0,
+                last_fail: now,
+            });
+
+        if let Some(threshold) = spec.reset_after {
+            if now.duration_since(entry.last_fail) >= threshold {
+                entry.restart_count = 0;
+            }
+        }
+
+        entry.restart_count += 1;
+        entry.last_fail = now;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Actor implementation
+// ---------------------------------------------------------------------------
+
+#[ractor::async_trait]
+impl Actor for DynamicSupervisor {
+    type Msg = 
DynamicSupervisorMsg;
+    type State = DynamicSupervisorState;
+    type Arguments = DynamicSupervisorOptions;
+
+    async fn pre_start(
+        &self,
+        _myself: ActorRef<Self::Msg>,
+        options: Self::Arguments,
+    ) -> Result<Self::State, ActorProcessingErr> {
+        Ok(DynamicSupervisorState {
+            options,
+            active_children: HashMap::new(),
+            child_failure_state: HashMap::new(),
+            restart_log: Vec::new(),
+        })
+    }
+
+    async fn handle(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        msg: Self::Msg,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        match msg {
+            DynamicSupervisorMsg::SpawnChild { spec, reply } => {
+                let result =
+                    handle_spawn_child(&spec, reply.is_some(), state, myself.clone()).await;
+                if let Some(reply) = reply {
+                    reply.send(result)?;
+                    Ok(())
+                } else {
+                    result
+                }
+            }
+            DynamicSupervisorMsg::TerminateChild { child_id, reply } => {
+                handle_terminate_child(&child_id, state, &myself);
+                if let Some(reply) = reply {
+                    reply.send(())?;
+                }
+                Ok(())
+            }
+            DynamicSupervisorMsg::ScheduledRestart { spec } => {
+                handle_spawn_child(&spec, false, state, myself).await
+            }
+        }
+    }
+
+    async fn handle_supervisor_evt(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        evt: SupervisionEvent,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        match evt {
+            SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {}
+
+            SupervisionEvent::ActorTerminated(cell, _, reason) => {
+                handle_child_restart(cell, false, state, &myself, reason.as_deref())?;
+            }
+
+            SupervisionEvent::ActorFailed(cell, err) => {
+                let reason = format!("{:?}", err);
+                handle_child_restart(cell, true, state, &myself, Some(&reason))?;
+            }
+        }
+        Ok(())
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+async fn handle_spawn_child(
+    spec: &DynChildSpec,
+    first_start: bool,
+    state: &mut DynamicSupervisorState,
+    myself: ActorRef<DynamicSupervisorMsg>,
+) -> Result<(), ActorProcessingErr> {
+    if !first_start {
+        state.track_global_restart(&spec.id)?;
+        tokio::time::sleep(Duration::from_millis(10)).await;
+    }
+
+    if let Some(max) = state.options.max_children {
+        if state.active_children.len() >= max {
+            return Err(SupervisorError::MaxChildrenExceeded.into());
+        }
+    }
+
+    let result = spec.spawn_fn.call(myself.get_cell(), spec.id.clone()).await;
+
+    match result {
+        Ok(child_cell) => {
+            state.active_children.insert(
+                spec.id.clone(),
+                ActiveChild {
+                    spec: spec.clone(),
+                    cell: child_cell,
+                },
+            );
+            state
+                .child_failure_state
+                .entry(spec.id.clone())
+                .or_insert(ChildFailureState {
+                    restart_count: 0,
+                    last_fail: Instant::now(),
+                });
+            Ok(())
+        }
+        Err(e) => Err(SupervisorError::Meltdown {
+            reason: format!("spawn failed for '{}': {}", spec.id, e),
+        }
+        .into()),
+    }
+}
+
+fn handle_terminate_child(
+    child_id: &str,
+    state: &mut DynamicSupervisorState,
+    myself: &ActorRef<DynamicSupervisorMsg>,
+) {
+    if let Some(child) = state.active_children.remove(child_id) {
+        child.cell.unlink(myself.get_cell());
+        child.cell.kill();
+    }
+}
+
+fn handle_child_restart(
+    cell: ActorCell,
+    abnormal: bool,
+    state: &mut DynamicSupervisorState,
+    myself: &ActorRef<DynamicSupervisorMsg>,
+    _reason: Option<&str>,
+) -> Result<(), ActorProcessingErr> {
+    let child_id = cell
+        .get_name()
+        .ok_or(SupervisorError::ChildNameNotSet { pid: cell.get_id() })?;
+
+    let child = match state.active_children.remove(&child_id) {
+        Some(c) => c,
+        None => return Ok(()),
+    };
+
+    let should_restart = match child.spec.restart {
+        crate::RestartPolicy::Permanent => true,
+        crate::RestartPolicy::Transient => 
abnormal, + crate::RestartPolicy::Temporary => false, + }; + + if !should_restart { + return Ok(()); + } + + state.prepare_child_failure(&child.spec); + + let delay = child.spec.backoff_fn.as_ref().and_then(|bf| { + let fs = state.child_failure_state.get(&child.spec.id); + let (count, last_fail) = fs + .map(|f| (f.restart_count, f.last_fail)) + .unwrap_or((0, Instant::now())); + bf.call(&child.spec.id, count, last_fail, child.spec.reset_after) + }); + + let spec = child.spec.clone(); + match delay { + Some(d) => { + let dur = ractor::concurrency::Duration::from_millis(d.as_millis() as u64); + myself.send_after(dur, move || DynamicSupervisorMsg::ScheduledRestart { spec }); + } + None => { + myself.send_message(DynamicSupervisorMsg::ScheduledRestart { spec })?; + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::RestartPolicy; + use ractor::{ActorRef, ActorStatus}; + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; + + static TEST_SEQ: AtomicU64 = AtomicU64::new(0); + + fn unique_name(prefix: &str) -> String { + format!( + "{prefix}_{}_{}", + std::process::id(), + TEST_SEQ.fetch_add(1, Ordering::Relaxed) + ) + } + + #[derive(Clone)] + enum ChildBehavior { + Healthy, + DelayedFail { ms: u64 }, + DelayedNormal { ms: u64 }, + } + + struct TestChild { + counter: Arc, + } + + #[ractor::async_trait] + impl Actor for TestChild { + type Msg = (); + type State = ChildBehavior; + type Arguments = ChildBehavior; + + async fn pre_start( + &self, + myself: ActorRef, + behavior: Self::Arguments, + ) -> Result { + self.counter.fetch_add(1, Ordering::SeqCst); + + match behavior { + ChildBehavior::DelayedFail { ms } | ChildBehavior::DelayedNormal { ms } => { + myself.send_after(Duration::from_millis(ms), || ()); + } + ChildBehavior::Healthy => {} + } + + Ok(behavior) + } + + async fn handle( + &self, + myself: ActorRef, + _msg: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match state { + ChildBehavior::DelayedFail { .. } => panic!("delayed_fail"), + ChildBehavior::DelayedNormal { .. 
} => myself.stop(None), + ChildBehavior::Healthy => {} + } + Ok(()) + } + } + + fn make_spec( + id: &str, + restart: RestartPolicy, + behavior: ChildBehavior, + counter: Arc, + ) -> DynChildSpec { + let id = id.to_string(); + DynChildSpec { + id: id.clone(), + restart, + spawn_fn: DynSpawnFn::new(move |sup_cell, child_id| { + let behavior = behavior.clone(); + let counter = counter.clone(); + async move { + let (child_ref, _join) = DynamicSupervisor::spawn_linked( + child_id, + TestChild { counter }, + behavior, + sup_cell, + ) + .await?; + Ok(child_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: None, + } + } + + fn options(max_restarts: usize) -> DynamicSupervisorOptions { + DynamicSupervisorOptions { + max_children: None, + max_restarts, + max_window: Duration::from_secs(5), + reset_after: None, + } + } + + #[tokio::test] + async fn transient_child_no_restart_on_normal_exit() { + let sup_name = unique_name("dyn_transient_normal_sup"); + let child_name = unique_name("dyn_transient_normal_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, options(5)) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Transient, + ChildBehavior::DelayedNormal { ms: 50 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(180)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert!( + !sup_ref + .get_children() + .iter() + .any(|c| c.get_status() == ActorStatus::Running) + ); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn temporary_child_never_restarts_on_failure() { + let sup_name = unique_name("dyn_temporary_sup"); + let child_name = unique_name("dyn_temporary_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, options(5)) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Temporary, + ChildBehavior::DelayedFail { ms: 50 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(180)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert!( + !sup_ref + .get_children() + .iter() + .any(|c| c.get_status() == ActorStatus::Running) + ); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn permanent_child_triggers_meltdown_when_budget_exceeded() { + let sup_name = unique_name("dyn_meltdown_sup"); + let child_name = unique_name("dyn_meltdown_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut opts = options(1); + opts.max_window = Duration::from_secs(2); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 40 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + assert_eq!(counter.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn 
reset_after_allows_restarts_across_quiet_periods() { + let sup_name = unique_name("dyn_reset_after_sup"); + let child_name = unique_name("dyn_reset_after_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut opts = options(1); + opts.max_window = Duration::from_secs(10); + opts.reset_after = Some(Duration::from_millis(80)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 140 }, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(520)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!(counter.load(Ordering::SeqCst) >= 3); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn max_children_is_enforced() { + let sup_name = unique_name("dyn_max_children_sup"); + let child_name_1 = unique_name("dyn_max_children_child1"); + let child_name_2 = unique_name("dyn_max_children_child2"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut opts = options(5); + opts.max_children = Some(1); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to spawn dynamic supervisor"); + + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name_1, + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + ), + ) + .await + .expect("first child should spawn"); + + let second = DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name_2, + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + ), + ) + .await; + + assert!(second.is_err()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn terminate_child_does_not_restart() { + let sup_name = unique_name("dyn_terminate_sup"); + let child_name = unique_name("dyn_terminate_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, options(5)) + .await + .expect("failed to spawn dynamic supervisor"); + + DynamicSupervisor::spawn_child( + sup_ref.clone(), + make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + ), + ) + .await + .expect("failed to spawn child"); + + tokio::time::sleep(Duration::from_millis(40)).await; + DynamicSupervisor::terminate_child(sup_ref.clone(), child_name) + .await + .expect("failed to terminate child"); + tokio::time::sleep(Duration::from_millis(120)).await; + + assert_eq!(counter.load(Ordering::SeqCst), 1); + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!( + !sup_ref + .get_children() + .iter() + .any(|c| c.get_status() == ActorStatus::Running) + ); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn backoff_delays_second_restart_attempt() { + let sup_name = unique_name("dyn_backoff_sup"); + let child_name = unique_name("dyn_backoff_child"); + let counter = Arc::new(AtomicU32::new(0)); + + let mut spec = make_spec( + &child_name, + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 10 }, + counter.clone(), + ); + spec.backoff_fn = Some(ChildBackoffFn::new( + |_id, restart_count, _last, _child_reset| { + if restart_count <= 1 { + None + } else { + 
Some(Duration::from_millis(220)) + } + }, + )); + + let mut opts = options(1); + opts.max_window = Duration::from_secs(5); + + let start = std::time::Instant::now(); + let (sup_ref, sup_handle) = DynamicSupervisor::spawn(sup_name, opts) + .await + .expect("failed to spawn dynamic supervisor"); + DynamicSupervisor::spawn_child(sup_ref.clone(), spec) + .await + .expect("failed to spawn child"); + + let _ = sup_handle.await; + let elapsed = start.elapsed(); + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + assert!( + elapsed >= Duration::from_millis(200), + "expected delayed restart, got {elapsed:?}" + ); + assert_eq!(counter.load(Ordering::SeqCst), 2); + } +} diff --git a/crates/supervisor/src/lib.rs b/crates/supervisor/src/lib.rs new file mode 100644 index 0000000000..d145303e7e --- /dev/null +++ b/crates/supervisor/src/lib.rs @@ -0,0 +1,10 @@ +pub mod dynamic; +mod restart; +mod retry; +mod supervisor; + +pub use restart::{RestartBudget, RestartTracker}; +pub use retry::{RetryStrategy, spawn_with_retry}; +pub use supervisor::{ + ChildSpec, RestartPolicy, SpawnFn, Supervisor, SupervisorConfig, SupervisorMsg, +}; diff --git a/crates/supervisor/src/restart.rs b/crates/supervisor/src/restart.rs new file mode 100644 index 0000000000..0d28c16679 --- /dev/null +++ b/crates/supervisor/src/restart.rs @@ -0,0 +1,164 @@ +use std::time::{Duration, Instant}; + +#[derive(Debug, Clone)] +pub struct RestartBudget { + pub max_restarts: u32, + pub max_window: Duration, + pub reset_after: Option, +} + +impl Default for RestartBudget { + fn default() -> Self { + Self { + max_restarts: 3, + max_window: Duration::from_secs(15), + reset_after: Some(Duration::from_secs(30)), + } + } +} + +pub struct RestartTracker { + count: u32, + window_start: Instant, +} + +impl Default for RestartTracker { + fn default() -> Self { + Self::new() + } +} + +impl RestartTracker { + pub fn new() -> Self { + Self { + count: 0, + window_start: Instant::now(), + } + } + + /// Records a restart and returns whether the budget still allows it. + /// `true` = within budget, `false` = meltdown threshold exceeded. + pub fn record_restart(&mut self, budget: &RestartBudget) -> bool { + let now = Instant::now(); + if now.duration_since(self.window_start) > budget.max_window { + self.count = 0; + self.window_start = now; + } + self.count += 1; + self.count <= budget.max_restarts + } + + /// Resets the counter if quiet for longer than `budget.reset_after`. 
+ pub fn maybe_reset(&mut self, budget: &RestartBudget) { + if let Some(reset_after) = budget.reset_after { + let now = Instant::now(); + if now.duration_since(self.window_start) > reset_after { + self.count = 0; + self.window_start = now; + } + } + } + + pub fn count(&self) -> u32 { + self.count + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn budget(max_restarts: u32, max_window_ms: u64) -> RestartBudget { + RestartBudget { + max_restarts, + max_window: Duration::from_millis(max_window_ms), + reset_after: None, + } + } + + #[test] + fn new_tracker_starts_at_zero() { + let tracker = RestartTracker::new(); + assert_eq!(tracker.count(), 0); + } + + #[test] + fn record_restart_increments_within_budget() { + let mut tracker = RestartTracker::new(); + let b = budget(3, 10_000); + assert!(tracker.record_restart(&b)); + assert!(tracker.record_restart(&b)); + assert!(tracker.record_restart(&b)); + assert_eq!(tracker.count(), 3); + } + + #[test] + fn record_restart_exceeds_budget() { + let mut tracker = RestartTracker::new(); + let b = budget(2, 10_000); + assert!(tracker.record_restart(&b)); + assert!(tracker.record_restart(&b)); + assert!(!tracker.record_restart(&b)); + } + + #[test] + fn record_restart_at_exact_boundary() { + let mut tracker = RestartTracker::new(); + let b = budget(1, 10_000); + assert!(tracker.record_restart(&b)); + assert!(!tracker.record_restart(&b)); + } + + #[test] + fn window_expiry_resets_counter() { + let mut tracker = RestartTracker::new(); + let b = budget(1, 50); + assert!(tracker.record_restart(&b)); + std::thread::sleep(Duration::from_millis(100)); + assert!(tracker.record_restart(&b)); + } + + #[test] + fn maybe_reset_clears_after_quiet_period() { + let mut tracker = RestartTracker::new(); + let mut b = budget(3, 10_000); + b.reset_after = Some(Duration::from_millis(50)); + + tracker.record_restart(&b); + tracker.record_restart(&b); + assert_eq!(tracker.count(), 2); + + std::thread::sleep(Duration::from_millis(100)); + tracker.maybe_reset(&b); + assert_eq!(tracker.count(), 0); + } + + #[test] + fn maybe_reset_noop_before_threshold() { + let mut tracker = RestartTracker::new(); + let mut b = budget(3, 10_000); + b.reset_after = Some(Duration::from_secs(60)); + + tracker.record_restart(&b); + tracker.record_restart(&b); + tracker.maybe_reset(&b); + assert_eq!(tracker.count(), 2); + } + + #[test] + fn maybe_reset_noop_when_none() { + let mut tracker = RestartTracker::new(); + let b = budget(3, 10_000); + + tracker.record_restart(&b); + tracker.maybe_reset(&b); + assert_eq!(tracker.count(), 1); + } + + #[test] + fn zero_budget_always_exceeds() { + let mut tracker = RestartTracker::new(); + let b = budget(0, 10_000); + assert!(!tracker.record_restart(&b)); + } +} diff --git a/crates/supervisor/src/retry.rs b/crates/supervisor/src/retry.rs new file mode 100644 index 0000000000..2f576b49c8 --- /dev/null +++ b/crates/supervisor/src/retry.rs @@ -0,0 +1,187 @@ +use ractor::{ActorCell, SpawnErr}; +use std::future::Future; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub struct RetryStrategy { + pub max_attempts: u32, + pub base_delay: Duration, +} + +impl Default for RetryStrategy { + fn default() -> Self { + Self { + max_attempts: 3, + base_delay: Duration::from_millis(100), + } + } +} + +/// Spawn an actor with exponential-backoff retries. +/// Returns `Some(cell)` on first success, `None` if all attempts fail. 
+pub async fn spawn_with_retry<F, Fut>(strategy: &RetryStrategy, spawn_fn: F) -> Option<ActorCell>
+where
+    F: Fn() -> Fut,
+    Fut: Future<Output = Result<ActorCell, SpawnErr>>,
+{
+    for attempt in 0..strategy.max_attempts {
+        let delay = strategy.base_delay * 2u32.pow(attempt);
+        tokio::time::sleep(delay).await;
+
+        match spawn_fn().await {
+            Ok(cell) => {
+                tracing::info!(attempt, "spawn_retry_succeeded");
+                return Some(cell);
+            }
+            Err(e) => {
+                tracing::warn!(attempt, error = ?e, "spawn_retry_failed");
+            }
+        }
+    }
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ractor::{Actor, ActorProcessingErr, ActorRef};
+    use std::sync::Arc;
+    use std::sync::atomic::{AtomicU64, Ordering};
+    use std::time::Instant;
+
+    static TEST_SEQ: AtomicU64 = AtomicU64::new(0);
+
+    fn unique_name(prefix: &str) -> String {
+        format!(
+            "{prefix}_{}_{}",
+            std::process::id(),
+            TEST_SEQ.fetch_add(1, Ordering::Relaxed)
+        )
+    }
+
+    struct DummyActor;
+
+    #[ractor::async_trait]
+    impl Actor for DummyActor {
+        type Msg = ();
+        type State = ();
+        type Arguments = ();
+
+        async fn pre_start(
+            &self,
+            _myself: ActorRef<Self::Msg>,
+            _args: Self::Arguments,
+        ) -> Result<Self::State, ActorProcessingErr> {
+            Ok(())
+        }
+    }
+
+    async fn spawn_name_collision_err(name: &str) -> SpawnErr {
+        Actor::spawn(Some(name.to_string()), DummyActor, ())
+            .await
+            .expect_err("expected name collision to produce SpawnErr")
+    }
+
+    #[tokio::test]
+    async fn zero_attempts_never_calls_spawn() {
+        let strategy = RetryStrategy {
+            max_attempts: 0,
+            base_delay: Duration::from_millis(5),
+        };
+        let calls = Arc::new(AtomicU64::new(0));
+        let calls2 = calls.clone();
+
+        let result = spawn_with_retry(&strategy, move || {
+            calls2.fetch_add(1, Ordering::SeqCst);
+            async {
+                panic!("spawn closure must not be called when max_attempts is zero");
+                #[allow(unreachable_code)]
+                Err(spawn_name_collision_err("unused").await)
+            }
+        })
+        .await;
+
+        assert!(result.is_none());
+        assert_eq!(calls.load(Ordering::SeqCst), 0);
+    }
+
+    #[tokio::test]
+    async fn retries_until_success() {
+        let keep_name = unique_name("retry_keep");
+        let (keep_ref, keep_handle) = Actor::spawn(Some(keep_name.clone()), DummyActor, ())
+            .await
+            .expect("failed to spawn keeper actor");
+
+        let success_name = unique_name("retry_success");
+        let attempts = Arc::new(AtomicU64::new(0));
+        let attempts2 = attempts.clone();
+
+        let strategy = RetryStrategy {
+            max_attempts: 4,
+            base_delay: Duration::from_millis(10),
+        };
+
+        let start = Instant::now();
+        let result = spawn_with_retry(&strategy, move || {
+            let keep_name = keep_name.clone();
+            let success_name = success_name.clone();
+            let attempts = attempts2.clone();
+            async move {
+                let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
+                if attempt < 3 {
+                    Err(spawn_name_collision_err(&keep_name).await)
+                } else {
+                    let (ok_ref, _ok_handle) =
+                        Actor::spawn(Some(success_name), DummyActor, ()).await?;
+                    Ok(ok_ref.get_cell())
+                }
+            }
+        })
+        .await;
+
+        let elapsed = start.elapsed();
+        assert!(result.is_some());
+        assert_eq!(attempts.load(Ordering::SeqCst), 3);
+        assert!(
+            elapsed >= Duration::from_millis(60),
+            "expected backoff delays before success, got {elapsed:?}"
+        );
+
+        if let Some(cell) = result {
+            cell.kill();
+        }
+        keep_ref.stop(None);
+        let _ = keep_handle.await;
+    }
+
+    #[tokio::test]
+    async fn returns_none_after_exhausting_attempts() {
+        let keep_name = unique_name("retry_keep_fail");
+        let (keep_ref, keep_handle) = Actor::spawn(Some(keep_name.clone()), DummyActor, ())
+            .await
+            .expect("failed to spawn keeper actor");
+
+        let attempts = Arc::new(AtomicU64::new(0));
+        let attempts2 = attempts.clone();
+        let strategy = RetryStrategy {
+            max_attempts: 3,
+            base_delay: Duration::from_millis(10),
+        };
+
+        let result = spawn_with_retry(&strategy, move || {
+            let keep_name = keep_name.clone();
+            let attempts = attempts2.clone();
+            async move {
+                attempts.fetch_add(1, Ordering::SeqCst);
+                Err(spawn_name_collision_err(&keep_name).await)
+            }
+        })
+        .await;
+
+        assert!(result.is_none());
+        assert_eq!(attempts.load(Ordering::SeqCst), 3);
+
+        keep_ref.stop(None);
+        let _ = keep_handle.await;
+    }
+}
diff --git a/crates/supervisor/src/supervisor.rs b/crates/supervisor/src/supervisor.rs
new file mode 100644
index 0000000000..9409562d75
--- /dev/null
+++ b/crates/supervisor/src/supervisor.rs
@@ -0,0 +1,615 @@
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+
+use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SpawnErr, SupervisionEvent};
+
+use crate::restart::{RestartBudget, RestartTracker};
+use crate::retry::{RetryStrategy, spawn_with_retry};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RestartPolicy {
+    Permanent,
+    Transient,
+    Temporary,
+}
+
+pub type SpawnFuture = Pin<Box<dyn Future<Output = Result<ActorCell, SpawnErr>> + Send>>;
+
+#[derive(Clone)]
+pub struct SpawnFn(Arc<dyn Fn(ActorCell) -> SpawnFuture + Send + Sync>);
+
+impl SpawnFn {
+    pub fn new<F, Fut>(f: F) -> Self
+    where
+        F: Fn(ActorCell) -> Fut + Send + Sync + 'static,
+        Fut: Future<Output = Result<ActorCell, SpawnErr>> + Send + 'static,
+    {
+        Self(Arc::new(move |cell| Box::pin(f(cell))))
+    }
+
+    pub async fn call(&self, supervisor_cell: ActorCell) -> Result<ActorCell, SpawnErr> {
+        (self.0)(supervisor_cell).await
+    }
+}
+
+pub struct ChildSpec {
+    pub id: String,
+    pub restart_policy: RestartPolicy,
+    pub spawn_fn: SpawnFn,
+}
+
+pub struct SupervisorConfig {
+    pub children: Vec<ChildSpec>,
+    pub restart_budget: RestartBudget,
+    pub retry_strategy: RetryStrategy,
+}
+
+struct ChildEntry {
+    id: String,
+    cell: Option<ActorCell>,
+    policy: RestartPolicy,
+    spawn_fn: SpawnFn,
+    tracker: RestartTracker,
+}
+
+pub struct SupervisorState {
+    children: Vec<ChildEntry>,
+    budget: RestartBudget,
+    retry_strategy: RetryStrategy,
+    shutting_down: bool,
+}
+
+pub struct Supervisor;
+
+#[derive(Debug)]
+pub enum SupervisorMsg {
+    Shutdown,
+}
+
+impl SupervisorState {
+    fn find_child_index(&self, cell: &ActorCell) -> Option<usize> {
+        self.children
+            .iter()
+            .position(|e| e.cell.as_ref().is_some_and(|c| c.get_id() == cell.get_id()))
+    }
+
+    fn stop_all_children(&mut self) {
+        for entry in &mut self.children {
+            if let Some(cell) = entry.cell.take() {
+                cell.stop(Some("supervisor_shutdown".to_string()));
+            }
+        }
+    }
+}
+
+fn should_restart(policy: RestartPolicy, abnormal: bool) -> bool {
+    match policy {
+        RestartPolicy::Permanent => true,
+        RestartPolicy::Transient => abnormal,
+        RestartPolicy::Temporary => false,
+    }
+}
+
+async fn handle_child_exit(
+    myself: &ActorRef<SupervisorMsg>,
+    state: &mut SupervisorState,
+    idx: usize,
+    abnormal: bool,
+) {
+    state.children[idx].cell = None;
+
+    if !should_restart(state.children[idx].policy, abnormal) {
+        return;
+    }
+
+    if !state.children[idx].tracker.record_restart(&state.budget) {
+        tracing::error!(child = %state.children[idx].id, "restart_limit_exceeded");
+        state.shutting_down = true;
+        state.stop_all_children();
+        myself.stop(Some("meltdown".to_string()));
+        return;
+    }
+
+    let spawn_fn = state.children[idx].spawn_fn.clone();
+    let sup_cell = myself.get_cell();
+    let retry = state.retry_strategy.clone();
+
+    let new_cell = spawn_with_retry(&retry, || {
+        let sup = sup_cell.clone();
+        let sf = spawn_fn.clone();
+        async move { sf.call(sup).await }
+    })
+    .await;
+
+    match new_cell {
+        Some(cell) => {
+            state.children[idx].cell = Some(cell);
+        }
+        None => {
+            tracing::error!(child = %state.children[idx].id, "spawn_retry_exhausted");
+            state.shutting_down = true;
+            state.stop_all_children();
+            myself.stop(Some("spawn_retry_exhausted".to_string()));
+        }
+    }
+}
+
+#[ractor::async_trait]
+impl Actor for Supervisor {
+    type Msg = SupervisorMsg;
+    type State = SupervisorState;
+    type Arguments = SupervisorConfig;
+
+    async fn pre_start(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        config: Self::Arguments,
+    ) -> Result<Self::State, ActorProcessingErr> {
+        let mut children = Vec::new();
+
+        for spec in config.children {
+            let cell =
+                spec.spawn_fn
+                    .call(myself.get_cell())
+                    .await
+                    .map_err(|e| -> ActorProcessingErr {
+                        format!("failed to spawn child '{}': {}", spec.id, e).into()
+                    })?;
+
+            children.push(ChildEntry {
+                id: spec.id,
+                cell: Some(cell),
+                policy: spec.restart_policy,
+                spawn_fn: spec.spawn_fn,
+                tracker: RestartTracker::new(),
+            });
+        }
+
+        Ok(SupervisorState {
+            children,
+            budget: config.restart_budget,
+            retry_strategy: config.retry_strategy,
+            shutting_down: false,
+        })
+    }
+
+    async fn handle(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        message: Self::Msg,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        match message {
+            SupervisorMsg::Shutdown => {
+                state.shutting_down = true;
+                state.stop_all_children();
+                myself.stop(None);
+            }
+        }
+        Ok(())
+    }
+
+    async fn handle_supervisor_evt(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        message: SupervisionEvent,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        for child in &mut state.children {
+            child.tracker.maybe_reset(&state.budget);
+        }
+
+        if state.shutting_down {
+            return Ok(());
+        }
+
+        match message {
+            SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {}
+
+            SupervisionEvent::ActorTerminated(cell, _, _reason) => {
+                if let Some(idx) = state.find_child_index(&cell) {
+                    handle_child_exit(&myself, state, idx, false).await;
+                }
+            }
+
+            SupervisionEvent::ActorFailed(cell, _error) => {
+                if let Some(idx) = state.find_child_index(&cell) {
+                    handle_child_exit(&myself, state, idx, true).await;
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ractor::{Actor, ActorRef, ActorStatus};
+    use std::sync::Arc;
+    use std::sync::atomic::{AtomicU32, Ordering};
+    use std::time::Duration;
+
+    // ---- Test child actor with configurable behaviors ----
+
+    #[derive(Clone)]
+    enum ChildBehavior {
+        Healthy,
+        DelayedFail { ms: u64 },
+        DelayedNormal { ms: u64 },
+    }
+
+    struct TestChild {
+        counter: Arc<AtomicU32>,
+    }
+
+    #[ractor::async_trait]
+    impl Actor for TestChild {
+        type Msg = ();
+        type State = ChildBehavior;
+        type Arguments = ChildBehavior;
+
+        async fn pre_start(
+            &self,
+            myself: ActorRef<Self::Msg>,
+            behavior: Self::Arguments,
+        ) -> Result<Self::State, ActorProcessingErr> {
+            self.counter.fetch_add(1, Ordering::SeqCst);
+
+            match &behavior {
+                ChildBehavior::Healthy => {}
+                ChildBehavior::DelayedFail { ms } => {
+                    myself.send_after(Duration::from_millis(*ms), || ());
+                }
+                ChildBehavior::DelayedNormal { ms } => {
+                    myself.send_after(Duration::from_millis(*ms), || ());
+                }
+            }
+            Ok(behavior)
+        }
+
+        async fn handle(
+            &self,
+            myself: ActorRef<Self::Msg>,
+            _msg: Self::Msg,
+            state: &mut Self::State,
+        ) -> Result<(), ActorProcessingErr> {
+            match state {
+                ChildBehavior::DelayedFail { .. } => panic!("delayed_fail"),
+                ChildBehavior::DelayedNormal { ..
} => myself.stop(None), + _ => {} + } + Ok(()) + } + } + + // ---- Helpers ---- + + fn make_child_spec( + name: &str, + policy: RestartPolicy, + behavior: ChildBehavior, + counter: Arc, + ) -> ChildSpec { + let name = name.to_string(); + ChildSpec { + id: name.clone(), + restart_policy: policy, + spawn_fn: SpawnFn::new(move |sup_cell| { + let behavior = behavior.clone(); + let counter = counter.clone(); + let name = name.clone(); + async move { + let (actor_ref, _) = + Actor::spawn_linked(Some(name), TestChild { counter }, behavior, sup_cell) + .await?; + Ok(actor_ref.get_cell()) + } + }), + } + } + + fn test_budget(max_restarts: u32) -> RestartBudget { + RestartBudget { + max_restarts, + max_window: Duration::from_secs(10), + reset_after: None, + } + } + + fn fast_retry() -> RetryStrategy { + RetryStrategy { + max_attempts: 3, + base_delay: Duration::from_millis(20), + } + } + + // ---- Tests ---- + + #[tokio::test] + async fn permanent_child_restarts_on_failure() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "perm_restart_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(1), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_perm_restart".to_string()), Supervisor, config) + .await + .unwrap(); + + // Child fails once -> restarted -> fails again -> meltdown + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + // initial + 1 restart = 2 spawns + assert_eq!(counter.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn transient_no_restart_on_normal_exit() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "trans_normal_child", + RestartPolicy::Transient, + ChildBehavior::DelayedNormal { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_trans_normal".to_string()), Supervisor, config) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(300)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn transient_restarts_on_failure() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "trans_fail_child", + RestartPolicy::Transient, + ChildBehavior::DelayedFail { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(1), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_trans_fail".to_string()), Supervisor, config) + .await + .unwrap(); + + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + assert_eq!(counter.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn temporary_never_restarts() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "temp_child", + RestartPolicy::Temporary, + ChildBehavior::DelayedFail { ms: 100 }, + counter.clone(), + )], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_temp_never".to_string()), Supervisor, config) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(300)).await; + 
assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn meltdown_on_budget_exceeded() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "meltdown_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 50 }, + counter.clone(), + )], + restart_budget: test_budget(2), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_meltdown".to_string()), Supervisor, config) + .await + .unwrap(); + + let _ = sup_handle.await; + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + // initial + 2 restarts = 3 spawns, then meltdown on 3rd failure + assert_eq!(counter.load(Ordering::SeqCst), 3); + } + + #[tokio::test] + async fn reset_after_quiet_period() { + let counter = Arc::new(AtomicU32::new(0)); + + // Budget: 1 restart in 10s window, but reset_after 200ms + let budget = RestartBudget { + max_restarts: 1, + max_window: Duration::from_secs(10), + reset_after: Some(Duration::from_millis(200)), + }; + + let config = SupervisorConfig { + children: vec![make_child_spec( + "reset_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 300 }, + counter.clone(), + )], + restart_budget: budget, + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_reset_quiet".to_string()), Supervisor, config) + .await + .unwrap(); + + // Child fails at ~300ms, restarted (count=1). + // Fails again at ~600ms. reset_after=200ms < 300ms gap => counter resets => count=1 again, no meltdown. + // Let it run through a few cycles. + tokio::time::sleep(Duration::from_millis(1500)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!(counter.load(Ordering::SeqCst) >= 3); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn shutdown_suppresses_restarts() { + let counter = Arc::new(AtomicU32::new(0)); + let config = SupervisorConfig { + children: vec![make_child_spec( + "shutdown_child", + RestartPolicy::Permanent, + ChildBehavior::Healthy, + counter.clone(), + )], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = Actor::spawn( + Some("test_shutdown_suppress".to_string()), + Supervisor, + config, + ) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + + let _ = sup_ref.cast(SupervisorMsg::Shutdown); + let _ = sup_handle.await; + + assert_eq!(sup_ref.get_status(), ActorStatus::Stopped); + // Only spawned once; shutdown should not trigger restarts + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn multiple_children_independent() { + let healthy_counter = Arc::new(AtomicU32::new(0)); + let failing_counter = Arc::new(AtomicU32::new(0)); + + let config = SupervisorConfig { + children: vec![ + make_child_spec( + "multi_healthy", + RestartPolicy::Permanent, + ChildBehavior::Healthy, + healthy_counter.clone(), + ), + make_child_spec( + "multi_temp_fail", + RestartPolicy::Temporary, + ChildBehavior::DelayedFail { ms: 100 }, + failing_counter.clone(), + ), + ], + restart_budget: test_budget(5), + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_multi_indep".to_string()), Supervisor, config) + .await + .unwrap(); + + 
tokio::time::sleep(Duration::from_millis(300)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + + // Healthy child spawned once and still running + assert_eq!(healthy_counter.load(Ordering::SeqCst), 1); + // Temporary child failed once, not restarted + assert_eq!(failing_counter.load(Ordering::SeqCst), 1); + + // Supervisor still has the healthy child + let running: Vec<_> = sup_ref + .get_children() + .into_iter() + .filter(|c| c.get_status() == ActorStatus::Running) + .collect(); + assert_eq!(running.len(), 1); + + sup_ref.stop(None); + let _ = sup_handle.await; + } + + #[tokio::test] + async fn window_expiry_allows_more_restarts() { + let counter = Arc::new(AtomicU32::new(0)); + + // Budget: 1 restart in a 200ms window (no reset_after) + let budget = RestartBudget { + max_restarts: 1, + max_window: Duration::from_millis(200), + reset_after: None, + }; + + let config = SupervisorConfig { + children: vec![make_child_spec( + "window_child", + RestartPolicy::Permanent, + ChildBehavior::DelayedFail { ms: 300 }, + counter.clone(), + )], + restart_budget: budget, + retry_strategy: fast_retry(), + }; + + let (sup_ref, sup_handle) = + Actor::spawn(Some("test_window_expiry".to_string()), Supervisor, config) + .await + .unwrap(); + + // Each child lives ~300ms, then fails. Window is 200ms, so each failure starts a new window. + // With budget=1 per window, each failure is restart #1 in a fresh window => no meltdown. + tokio::time::sleep(Duration::from_millis(1200)).await; + assert_eq!(sup_ref.get_status(), ActorStatus::Running); + assert!(counter.load(Ordering::SeqCst) >= 3); + + sup_ref.stop(None); + let _ = sup_handle.await; + } +} diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 6560151f92..6dd3e9f5d9 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -63,7 +63,8 @@ uuid = { workspace = true, features = ["v4"] } hound = { workspace = true } vorbis_rs = { workspace = true } -ractor = { workspace = true } +hypr-supervisor = { workspace = true } +ractor = { workspace = true, features = ["async-trait"] } futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/plugins/listener/js/bindings.gen.ts b/plugins/listener/js/bindings.gen.ts index a37b6d3364..a2bae4c8c5 100644 --- a/plugins/listener/js/bindings.gen.ts +++ b/plugins/listener/js/bindings.gen.ts @@ -109,9 +109,10 @@ sessionProgressEvent: "plugin:listener:session-progress-event" /** user-defined types **/ +export type DegradedError = { type: "authentication_failed"; provider: string } | { type: "upstream_unavailable"; message: string } | { type: "connection_timeout" } | { type: "stream_error"; message: string } | { type: "channel_overflow" } export type SessionDataEvent = { type: "audio_amplitude"; session_id: string; mic: number; speaker: number } | { type: "mic_muted"; session_id: string; value: boolean } | { type: "stream_response"; session_id: string; response: StreamResponse } export type SessionErrorEvent = { type: "audio_error"; session_id: string; error: string; device: string | null; is_fatal: boolean } | { type: "connection_error"; session_id: string; error: string } -export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error: string | null } | { type: "active"; session_id: string } | { type: "finalizing"; session_id: string } +export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error: string | null } | { type: "active"; session_id: string; error?: 
DegradedError | null } | { type: "finalizing"; session_id: string } export type SessionParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[] } export type SessionProgressEvent = { type: "audio_initializing"; session_id: string } | { type: "audio_ready"; session_id: string; device: string | null } | { type: "connecting"; session_id: string } | { type: "connected"; session_id: string; adapter: string } export type State = "active" | "inactive" | "finalizing" diff --git a/plugins/listener/src/actors/root.rs b/plugins/listener/src/actors/root.rs index 17abd6e0c4..81819bf1a5 100644 --- a/plugins/listener/src/actors/root.rs +++ b/plugins/listener/src/actors/root.rs @@ -5,7 +5,6 @@ use tauri_plugin_settings::SettingsPluginExt; use tauri_specta::Event; use tracing::Instrument; -use crate::DegradedError; use crate::SessionLifecycleEvent; use crate::actors::session::lifecycle::{ clear_sentry_session_context, configure_sentry_session_context, emit_session_ended, @@ -106,13 +105,7 @@ impl Actor for RootActor { state.supervisor = None; state.finalizing = false; - let failure_reason = reason.and_then(|r| { - serde_json::from_str::(&r) - .ok() - .map(|d| format!("{:?}", d)) - .or(Some(r)) - }); - emit_session_ended(&state.app, &session_id, failure_reason); + emit_session_ended(&state.app, &session_id, reason); } } SupervisionEvent::ActorFailed(cell, error) => { @@ -223,6 +216,12 @@ async fn stop_session_impl(state: &mut RootState) { } let session_ref: ActorRef = supervisor.clone().into(); - let _ = session_ref.cast(SessionMsg::Shutdown); + if let Err(error) = session_ref.cast(SessionMsg::Shutdown) { + tracing::warn!( + ?error, + "failed_to_cast_session_shutdown_falling_back_to_stop" + ); + supervisor.stop(Some("session_stop_cast_failed".to_string())); + } } } diff --git a/plugins/listener/src/actors/session/mod.rs b/plugins/listener/src/actors/session/mod.rs index 1889f43b59..0610875edc 100644 --- a/plugins/listener/src/actors/session/mod.rs +++ b/plugins/listener/src/actors/session/mod.rs @@ -1,446 +1,6 @@ pub(crate) mod lifecycle; +mod supervisor; +mod types; -use std::path::PathBuf; -use std::time::{Instant, SystemTime}; - -use ractor::concurrency::Duration; -use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SupervisionEvent}; -use tauri_specta::Event; -use tracing::Instrument; - -use crate::DegradedError; -use crate::SessionLifecycleEvent; -use crate::actors::{ - ChannelMode, ListenerActor, ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs, -}; - -pub const SESSION_SUPERVISOR_PREFIX: &str = "session_supervisor_"; - -/// Creates a tracing span with session context that child events will inherit -pub(crate) fn session_span(session_id: &str) -> tracing::Span { - tracing::info_span!("session", session_id = %session_id) -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] -pub struct SessionParams { - pub session_id: String, - pub languages: Vec, - pub onboarding: bool, - pub record_enabled: bool, - pub model: String, - pub base_url: String, - pub api_key: String, - pub keywords: Vec, -} - -#[derive(Clone)] -pub struct SessionContext { - pub app: tauri::AppHandle, - pub params: SessionParams, - pub app_dir: PathBuf, - pub started_at_instant: Instant, - pub started_at_system: SystemTime, -} - -pub fn session_supervisor_name(session_id: &str) -> String { - format!("{}{}", SESSION_SUPERVISOR_PREFIX, session_id) -} - -#[derive(Debug, Clone, Copy, 
PartialEq, Eq)] -enum ChildKind { - Source, - Listener, - Recorder, -} - -struct RestartTracker { - count: u32, - window_start: Instant, -} - -impl RestartTracker { - fn new() -> Self { - Self { - count: 0, - window_start: Instant::now(), - } - } - - fn record_restart(&mut self, max_restarts: u32, max_window: Duration) -> bool { - let now = Instant::now(); - if now.duration_since(self.window_start) > max_window { - self.count = 0; - self.window_start = now; - } - self.count += 1; - self.count <= max_restarts - } - - fn maybe_reset(&mut self, reset_after: Duration) { - let now = Instant::now(); - if now.duration_since(self.window_start) > reset_after { - self.count = 0; - self.window_start = now; - } - } -} - -pub struct SessionState { - ctx: SessionContext, - source_cell: Option, - listener_cell: Option, - recorder_cell: Option, - listener_degraded: Option, - source_restarts: RestartTracker, - recorder_restarts: RestartTracker, - shutting_down: bool, -} - -pub struct SessionActor; - -impl SessionActor { - const MAX_RESTARTS: u32 = 3; - const MAX_WINDOW: Duration = Duration::from_secs(15); - const RESET_AFTER: Duration = Duration::from_secs(30); -} - -#[derive(Debug)] -pub enum SessionMsg { - Shutdown, -} - -#[ractor::async_trait] -impl Actor for SessionActor { - type Msg = SessionMsg; - type State = SessionState; - type Arguments = SessionContext; - - async fn pre_start( - &self, - myself: ActorRef, - ctx: Self::Arguments, - ) -> Result { - let session_id = ctx.params.session_id.clone(); - let span = session_span(&session_id); - - async { - let (source_ref, _) = Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - mic_device: None, - onboarding: ctx.params.onboarding, - app: ctx.app.clone(), - session_id: ctx.params.session_id.clone(), - }, - myself.get_cell(), - ) - .await?; - - let mode = ChannelMode::determine(ctx.params.onboarding); - let (listener_ref, _) = Actor::spawn_linked( - Some(ListenerActor::name()), - ListenerActor, - ListenerArgs { - app: ctx.app.clone(), - languages: ctx.params.languages.clone(), - onboarding: ctx.params.onboarding, - model: ctx.params.model.clone(), - base_url: ctx.params.base_url.clone(), - api_key: ctx.params.api_key.clone(), - keywords: ctx.params.keywords.clone(), - mode, - session_started_at: ctx.started_at_instant, - session_started_at_unix: ctx.started_at_system, - session_id: ctx.params.session_id.clone(), - }, - myself.get_cell(), - ) - .await?; - - let recorder_cell = if ctx.params.record_enabled { - let (recorder_ref, _) = Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir: ctx.app_dir.clone(), - session_id: ctx.params.session_id.clone(), - }, - myself.get_cell(), - ) - .await?; - Some(recorder_ref.get_cell()) - } else { - None - }; - - Ok(SessionState { - ctx, - source_cell: Some(source_ref.get_cell()), - listener_cell: Some(listener_ref.get_cell()), - recorder_cell, - listener_degraded: None, - source_restarts: RestartTracker::new(), - recorder_restarts: RestartTracker::new(), - shutting_down: false, - }) - } - .instrument(span) - .await - } - - async fn handle( - &self, - myself: ActorRef, - message: Self::Msg, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match message { - SessionMsg::Shutdown => { - state.shutting_down = true; - - if let Some(cell) = state.recorder_cell.take() { - cell.stop(Some("session_stop".to_string())); - lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; - } - - if let Some(cell) = state.source_cell.take() { - 
cell.stop(Some("session_stop".to_string())); - } - if let Some(cell) = state.listener_cell.take() { - cell.stop(Some("session_stop".to_string())); - } - - myself.stop(None); - } - } - Ok(()) - } - - async fn handle_supervisor_evt( - &self, - myself: ActorRef, - message: SupervisionEvent, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - let span = session_span(&state.ctx.params.session_id); - let _guard = span.enter(); - - state.source_restarts.maybe_reset(Self::RESET_AFTER); - state.recorder_restarts.maybe_reset(Self::RESET_AFTER); - - if state.shutting_down { - return Ok(()); - } - - match message { - SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} - - SupervisionEvent::ActorTerminated(cell, _, reason) => { - match identify_child(state, &cell) { - Some(ChildKind::Listener) => { - tracing::info!(?reason, "listener_terminated_entering_degraded_mode"); - let degraded = reason - .as_ref() - .and_then(|r| serde_json::from_str::(r).ok()); - state.listener_degraded = degraded.clone(); - state.listener_cell = None; - - let _ = (SessionLifecycleEvent::Active { - session_id: state.ctx.params.session_id.clone(), - error: degraded, - }) - .emit(&state.ctx.app); - } - Some(ChildKind::Source) => { - tracing::info!(?reason, "source_terminated_attempting_restart"); - state.source_cell = None; - if !try_restart_source(myself.get_cell(), state).await { - tracing::error!("source_restart_limit_exceeded_meltdown"); - meltdown(myself, state).await; - } - } - Some(ChildKind::Recorder) => { - tracing::info!(?reason, "recorder_terminated_attempting_restart"); - state.recorder_cell = None; - if !try_restart_recorder(myself.get_cell(), state).await { - tracing::error!("recorder_restart_limit_exceeded_meltdown"); - meltdown(myself, state).await; - } - } - None => { - tracing::warn!("unknown_child_terminated"); - } - } - } - - SupervisionEvent::ActorFailed(cell, error) => match identify_child(state, &cell) { - Some(ChildKind::Listener) => { - tracing::info!(?error, "listener_failed_entering_degraded_mode"); - let degraded = DegradedError::StreamError { - message: format!("{:?}", error), - }; - state.listener_degraded = Some(degraded.clone()); - state.listener_cell = None; - - let _ = (SessionLifecycleEvent::Active { - session_id: state.ctx.params.session_id.clone(), - error: Some(degraded), - }) - .emit(&state.ctx.app); - } - Some(ChildKind::Source) => { - tracing::warn!(?error, "source_failed_attempting_restart"); - state.source_cell = None; - if !try_restart_source(myself.get_cell(), state).await { - tracing::error!("source_restart_limit_exceeded_meltdown"); - meltdown(myself, state).await; - } - } - Some(ChildKind::Recorder) => { - tracing::warn!(?error, "recorder_failed_attempting_restart"); - state.recorder_cell = None; - if !try_restart_recorder(myself.get_cell(), state).await { - tracing::error!("recorder_restart_limit_exceeded_meltdown"); - meltdown(myself, state).await; - } - } - None => { - tracing::warn!("unknown_child_failed"); - } - }, - } - Ok(()) - } -} - -fn identify_child(state: &SessionState, cell: &ActorCell) -> Option { - if state - .source_cell - .as_ref() - .is_some_and(|c| c.get_id() == cell.get_id()) - { - return Some(ChildKind::Source); - } - if state - .listener_cell - .as_ref() - .is_some_and(|c| c.get_id() == cell.get_id()) - { - return Some(ChildKind::Listener); - } - if state - .recorder_cell - .as_ref() - .is_some_and(|c| c.get_id() == cell.get_id()) - { - return Some(ChildKind::Recorder); - } - None -} - -async fn 
try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { - if !state - .source_restarts - .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) - { - return false; - } - - for attempt in 0..3u32 { - let delay = std::time::Duration::from_millis(100 * 2u64.pow(attempt)); - tokio::time::sleep(delay).await; - - match Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - mic_device: None, - onboarding: state.ctx.params.onboarding, - app: state.ctx.app.clone(), - session_id: state.ctx.params.session_id.clone(), - }, - supervisor_cell.clone(), - ) - .await - { - Ok((actor_ref, _)) => { - state.source_cell = Some(actor_ref.get_cell()); - tracing::info!(attempt, "source_restarted"); - return true; - } - Err(e) => { - tracing::warn!(attempt, error = ?e, "source_spawn_attempt_failed"); - } - } - } - - tracing::error!("source_restart_failed_all_attempts"); - false -} - -async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { - if !state.ctx.params.record_enabled { - return true; - } - - if !state - .recorder_restarts - .record_restart(SessionActor::MAX_RESTARTS, SessionActor::MAX_WINDOW) - { - return false; - } - - for attempt in 0..3u32 { - let delay = std::time::Duration::from_millis(100 * 2u64.pow(attempt)); - tokio::time::sleep(delay).await; - - match Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir: state.ctx.app_dir.clone(), - session_id: state.ctx.params.session_id.clone(), - }, - supervisor_cell.clone(), - ) - .await - { - Ok((actor_ref, _)) => { - state.recorder_cell = Some(actor_ref.get_cell()); - tracing::info!(attempt, "recorder_restarted"); - return true; - } - Err(e) => { - tracing::warn!(attempt, error = ?e, "recorder_spawn_attempt_failed"); - } - } - } - - tracing::error!("recorder_restart_failed_all_attempts"); - false -} - -async fn meltdown(myself: ActorRef, state: &mut SessionState) { - if let Some(cell) = state.source_cell.take() { - cell.stop(Some("meltdown".to_string())); - } - if let Some(cell) = state.listener_cell.take() { - cell.stop(Some("meltdown".to_string())); - } - if let Some(cell) = state.recorder_cell.take() { - cell.stop(Some("meltdown".to_string())); - lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; - } - myself.stop(Some("restart_limit_exceeded".to_string())); -} - -pub async fn spawn_session_supervisor( - ctx: SessionContext, -) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> { - let supervisor_name = session_supervisor_name(&ctx.params.session_id); - - let (actor_ref, handle) = Actor::spawn(Some(supervisor_name), SessionActor, ctx).await?; - - Ok((actor_ref.get_cell(), handle)) -} +pub use supervisor::*; +pub use types::*; diff --git a/plugins/listener/src/actors/session/supervisor.rs b/plugins/listener/src/actors/session/supervisor.rs new file mode 100644 index 0000000000..ca55e18d39 --- /dev/null +++ b/plugins/listener/src/actors/session/supervisor.rs @@ -0,0 +1,419 @@ +use hypr_supervisor::{RestartBudget, RestartTracker, RetryStrategy, spawn_with_retry}; +use ractor::concurrency::Duration; +use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SupervisionEvent}; +use tauri_specta::Event; +use tracing::Instrument; + +use crate::DegradedError; +use crate::SessionLifecycleEvent; +use crate::actors::session::lifecycle; +use crate::actors::session::types::{SessionContext, session_span, session_supervisor_name}; +use crate::actors::{ + ChannelMode, ListenerActor, 
ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs,
+};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ChildKind {
+    Source,
+    Listener,
+    Recorder,
+}
+
+const RESTART_BUDGET: RestartBudget = RestartBudget {
+    max_restarts: 3,
+    max_window: Duration::from_secs(15),
+    reset_after: Some(Duration::from_secs(30)),
+};
+
+const RETRY_STRATEGY: RetryStrategy = RetryStrategy {
+    max_attempts: 3,
+    base_delay: Duration::from_millis(100),
+};
+
+pub struct SessionState {
+    ctx: SessionContext,
+    source_cell: Option<ActorCell>,
+    listener_cell: Option<ActorCell>,
+    recorder_cell: Option<ActorCell>,
+    listener_degraded: Option<DegradedError>,
+    source_restarts: RestartTracker,
+    recorder_restarts: RestartTracker,
+    shutting_down: bool,
+}
+
+pub struct SessionActor;
+
+#[derive(Debug)]
+pub enum SessionMsg {
+    Shutdown,
+}
+
+#[ractor::async_trait]
+impl Actor for SessionActor {
+    type Msg = SessionMsg;
+    type State = SessionState;
+    type Arguments = SessionContext;
+
+    async fn pre_start(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        ctx: Self::Arguments,
+    ) -> Result<Self::State, ActorProcessingErr> {
+        let session_id = ctx.params.session_id.clone();
+        let span = session_span(&session_id);
+
+        async {
+            let (source_ref, _) = Actor::spawn_linked(
+                Some(SourceActor::name()),
+                SourceActor,
+                SourceArgs {
+                    mic_device: None,
+                    onboarding: ctx.params.onboarding,
+                    app: ctx.app.clone(),
+                    session_id: ctx.params.session_id.clone(),
+                },
+                myself.get_cell(),
+            )
+            .await?;
+
+            let mode = ChannelMode::determine(ctx.params.onboarding);
+            let (listener_ref, _) = Actor::spawn_linked(
+                Some(ListenerActor::name()),
+                ListenerActor,
+                ListenerArgs {
+                    app: ctx.app.clone(),
+                    languages: ctx.params.languages.clone(),
+                    onboarding: ctx.params.onboarding,
+                    model: ctx.params.model.clone(),
+                    base_url: ctx.params.base_url.clone(),
+                    api_key: ctx.params.api_key.clone(),
+                    keywords: ctx.params.keywords.clone(),
+                    mode,
+                    session_started_at: ctx.started_at_instant,
+                    session_started_at_unix: ctx.started_at_system,
+                    session_id: ctx.params.session_id.clone(),
+                },
+                myself.get_cell(),
+            )
+            .await?;
+
+            let recorder_cell = if ctx.params.record_enabled {
+                let (recorder_ref, _) = Actor::spawn_linked(
+                    Some(RecorderActor::name()),
+                    RecorderActor,
+                    RecArgs {
+                        app_dir: ctx.app_dir.clone(),
+                        session_id: ctx.params.session_id.clone(),
+                    },
+                    myself.get_cell(),
+                )
+                .await?;
+                Some(recorder_ref.get_cell())
+            } else {
+                None
+            };
+
+            Ok(SessionState {
+                ctx,
+                source_cell: Some(source_ref.get_cell()),
+                listener_cell: Some(listener_ref.get_cell()),
+                recorder_cell,
+                listener_degraded: None,
+                source_restarts: RestartTracker::new(),
+                recorder_restarts: RestartTracker::new(),
+                shutting_down: false,
+            })
+        }
+        .instrument(span)
+        .await
+    }
+
+    async fn handle(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        message: Self::Msg,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        match message {
+            SessionMsg::Shutdown => {
+                state.shutting_down = true;
+
+                if let Some(cell) = state.recorder_cell.take() {
+                    cell.stop(Some("session_stop".to_string()));
+                    lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await;
+                }
+
+                if let Some(cell) = state.source_cell.take() {
+                    cell.stop(Some("session_stop".to_string()));
+                }
+                if let Some(cell) = state.listener_cell.take() {
+                    cell.stop(Some("session_stop".to_string()));
+                }
+
+                myself.stop(None);
+            }
+        }
+        Ok(())
+    }
+
+    async fn handle_supervisor_evt(
+        &self,
+        myself: ActorRef<Self::Msg>,
+        message: SupervisionEvent,
+        state: &mut Self::State,
+    ) -> Result<(), ActorProcessingErr> {
+        let span = session_span(&state.ctx.params.session_id);
+        let _guard = 
span.enter(); + + state.source_restarts.maybe_reset(&RESTART_BUDGET); + state.recorder_restarts.maybe_reset(&RESTART_BUDGET); + + if state.shutting_down { + return Ok(()); + } + + match message { + SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} + + SupervisionEvent::ActorTerminated(cell, _, reason) => { + match identify_child(state, &cell) { + Some(ChildKind::Listener) => { + tracing::info!(?reason, "listener_terminated_entering_degraded_mode"); + let degraded = parse_degraded_reason(reason.as_ref()); + state.listener_degraded = Some(degraded.clone()); + state.listener_cell = None; + + let _ = (SessionLifecycleEvent::Active { + session_id: state.ctx.params.session_id.clone(), + error: Some(degraded), + }) + .emit(&state.ctx.app); + } + Some(ChildKind::Source) => { + tracing::info!(?reason, "source_terminated_attempting_restart"); + state.source_cell = None; + if !try_restart_source(myself.get_cell(), state).await { + tracing::error!("source_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + Some(ChildKind::Recorder) => { + tracing::info!(?reason, "recorder_terminated_attempting_restart"); + state.recorder_cell = None; + if !try_restart_recorder(myself.get_cell(), state).await { + tracing::error!("recorder_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + None => { + tracing::warn!("unknown_child_terminated"); + } + } + } + + SupervisionEvent::ActorFailed(cell, error) => match identify_child(state, &cell) { + Some(ChildKind::Listener) => { + tracing::info!(?error, "listener_failed_entering_degraded_mode"); + let degraded = DegradedError::StreamError { + message: format!("{:?}", error), + }; + state.listener_degraded = Some(degraded.clone()); + state.listener_cell = None; + + let _ = (SessionLifecycleEvent::Active { + session_id: state.ctx.params.session_id.clone(), + error: Some(degraded), + }) + .emit(&state.ctx.app); + } + Some(ChildKind::Source) => { + tracing::warn!(?error, "source_failed_attempting_restart"); + state.source_cell = None; + if !try_restart_source(myself.get_cell(), state).await { + tracing::error!("source_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + Some(ChildKind::Recorder) => { + tracing::warn!(?error, "recorder_failed_attempting_restart"); + state.recorder_cell = None; + if !try_restart_recorder(myself.get_cell(), state).await { + tracing::error!("recorder_restart_limit_exceeded_meltdown"); + meltdown(myself, state).await; + } + } + None => { + tracing::warn!("unknown_child_failed"); + } + }, + } + Ok(()) + } +} + +fn identify_child(state: &SessionState, cell: &ActorCell) -> Option { + if state + .source_cell + .as_ref() + .is_some_and(|c| c.get_id() == cell.get_id()) + { + return Some(ChildKind::Source); + } + if state + .listener_cell + .as_ref() + .is_some_and(|c| c.get_id() == cell.get_id()) + { + return Some(ChildKind::Listener); + } + if state + .recorder_cell + .as_ref() + .is_some_and(|c| c.get_id() == cell.get_id()) + { + return Some(ChildKind::Recorder); + } + None +} + +async fn try_restart_source(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { + if !state.source_restarts.record_restart(&RESTART_BUDGET) { + return false; + } + + let sup = supervisor_cell; + let onboarding = state.ctx.params.onboarding; + let app = state.ctx.app.clone(); + let session_id = state.ctx.params.session_id.clone(); + + let cell = spawn_with_retry(&RETRY_STRATEGY, || { + let sup = sup.clone(); + let app = app.clone(); + let session_id = 
session_id.clone(); + async move { + let (r, _) = Actor::spawn_linked( + Some(SourceActor::name()), + SourceActor, + SourceArgs { + mic_device: None, + onboarding, + app, + session_id, + }, + sup, + ) + .await?; + Ok(r.get_cell()) + } + }) + .await; + + match cell { + Some(c) => { + state.source_cell = Some(c); + true + } + None => false, + } +} + +async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionState) -> bool { + if !state.ctx.params.record_enabled { + return true; + } + + if !state.recorder_restarts.record_restart(&RESTART_BUDGET) { + return false; + } + + let sup = supervisor_cell; + let app_dir = state.ctx.app_dir.clone(); + let session_id = state.ctx.params.session_id.clone(); + + let cell = spawn_with_retry(&RETRY_STRATEGY, || { + let sup = sup.clone(); + let app_dir = app_dir.clone(); + let session_id = session_id.clone(); + async move { + let (r, _) = Actor::spawn_linked( + Some(RecorderActor::name()), + RecorderActor, + RecArgs { + app_dir, + session_id, + }, + sup, + ) + .await?; + Ok(r.get_cell()) + } + }) + .await; + + match cell { + Some(c) => { + state.recorder_cell = Some(c); + true + } + None => false, + } +} + +async fn meltdown(myself: ActorRef, state: &mut SessionState) { + state.shutting_down = true; + + if let Some(cell) = state.source_cell.take() { + cell.stop(Some("meltdown".to_string())); + } + if let Some(cell) = state.listener_cell.take() { + cell.stop(Some("meltdown".to_string())); + } + if let Some(cell) = state.recorder_cell.take() { + cell.stop(Some("meltdown".to_string())); + lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; + } + myself.stop(Some("restart_limit_exceeded".to_string())); +} + +fn parse_degraded_reason(reason: Option<&String>) -> DegradedError { + reason + .and_then(|r| serde_json::from_str::(r).ok()) + .unwrap_or_else(|| DegradedError::StreamError { + message: reason + .cloned() + .unwrap_or_else(|| "listener terminated without reason".to_string()), + }) +} + +pub async fn spawn_session_supervisor( + ctx: SessionContext, +) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> { + let supervisor_name = session_supervisor_name(&ctx.params.session_id); + let (actor_ref, handle) = Actor::spawn(Some(supervisor_name), SessionActor, ctx).await?; + Ok((actor_ref.get_cell(), handle)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_degraded_reason_uses_json_payload() { + let reason = serde_json::to_string(&DegradedError::ConnectionTimeout).unwrap(); + let parsed = parse_degraded_reason(Some(&reason)); + assert!(matches!(parsed, DegradedError::ConnectionTimeout)); + } + + #[test] + fn parse_degraded_reason_falls_back_for_missing_reason() { + let parsed = parse_degraded_reason(None); + assert!(matches!(parsed, DegradedError::StreamError { .. })); + } + + #[test] + fn parse_degraded_reason_falls_back_for_invalid_json() { + let reason = "not-json".to_string(); + let parsed = parse_degraded_reason(Some(&reason)); + assert!(matches!(parsed, DegradedError::StreamError { .. 
})); + } +} diff --git a/plugins/listener/src/actors/session/types.rs b/plugins/listener/src/actors/session/types.rs new file mode 100644 index 0000000000..9bf5df0b12 --- /dev/null +++ b/plugins/listener/src/actors/session/types.rs @@ -0,0 +1,33 @@ +use std::path::PathBuf; +use std::time::{Instant, SystemTime}; + +pub const SESSION_SUPERVISOR_PREFIX: &str = "session_supervisor_"; + +pub fn session_span(session_id: &str) -> tracing::Span { + tracing::info_span!("session", session_id = %session_id) +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct SessionParams { + pub session_id: String, + pub languages: Vec, + pub onboarding: bool, + pub record_enabled: bool, + pub model: String, + pub base_url: String, + pub api_key: String, + pub keywords: Vec, +} + +#[derive(Clone)] +pub struct SessionContext { + pub app: tauri::AppHandle, + pub params: SessionParams, + pub app_dir: PathBuf, + pub started_at_instant: Instant, + pub started_at_system: SystemTime, +} + +pub fn session_supervisor_name(session_id: &str) -> String { + format!("{}{}", SESSION_SUPERVISOR_PREFIX, session_id) +} diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index 8ffc7888e0..7814ce41a9 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -79,8 +79,8 @@ tokio = { workspace = true, features = ["rt", "macros"] } tokio-util = { workspace = true } tracing = { workspace = true } -ractor = { workspace = true } -ractor-supervisor = { workspace = true } +hypr-supervisor = { workspace = true } +ractor = { workspace = true, features = ["async-trait"] } port-killer = "0.1.0" port_check = "0.3.0" diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index 497becdbbf..727830d4c2 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -1,5 +1,5 @@ +use hypr_supervisor::dynamic::DynamicSupervisorMsg; use ractor::{ActorCell, ActorRef}; -use ractor_supervisor::dynamic::DynamicSupervisorMsg; use std::collections::HashMap; use tauri::{Manager, Wry}; use tokio_util::sync::CancellationToken; diff --git a/plugins/local-stt/src/server/supervisor.rs b/plugins/local-stt/src/server/supervisor.rs index 93cb6da32f..b89252aec2 100644 --- a/plugins/local-stt/src/server/supervisor.rs +++ b/plugins/local-stt/src/server/supervisor.rs @@ -1,8 +1,11 @@ -use ractor::{ActorCell, ActorProcessingErr, ActorRef, concurrency::Duration, registry}; -use ractor_supervisor::{ - core::{ChildBackoffFn, ChildSpec, Restart, SpawnFn, SupervisorError}, - dynamic::{DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions}, +use hypr_supervisor::{ + RestartPolicy, + dynamic::{ + ChildBackoffFn, DynChildSpec, DynSpawnFn, DynamicSupervisor, DynamicSupervisorMsg, + DynamicSupervisorOptions, SupervisorError, + }, }; +use ractor::{ActorCell, ActorProcessingErr, ActorRef, concurrency::Duration, registry}; #[cfg(feature = "whisper-cpp")] use super::internal::{InternalSTTActor, InternalSTTArgs}; @@ -59,8 +62,8 @@ pub async fn start_external_stt( } #[cfg(feature = "whisper-cpp")] -fn create_internal_child_spec_with_args(args: InternalSTTArgs) -> ChildSpec { - let spawn_fn = SpawnFn::new(move |supervisor: ActorCell, child_id: String| { +fn create_internal_child_spec_with_args(args: InternalSTTArgs) -> DynChildSpec { + let spawn_fn = DynSpawnFn::new(move |supervisor: ActorCell, child_id: String| { let args = args.clone(); async move { let (actor_ref, _handle) = @@ -70,10 +73,10 @@ fn create_internal_child_spec_with_args(args: InternalSTTArgs) -> 
ChildSpec { } }); - ChildSpec { + DynChildSpec { id: INTERNAL_STT_ACTOR_NAME.to_string(), spawn_fn, - restart: Restart::Transient, + restart: RestartPolicy::Transient, backoff_fn: Some(ChildBackoffFn::new(|_, _, _, _| { Some(Duration::from_millis(500)) })), @@ -81,8 +84,8 @@ fn create_internal_child_spec_with_args(args: InternalSTTArgs) -> ChildSpec { } } -fn create_external_child_spec_with_args(args: ExternalSTTArgs) -> ChildSpec { - let spawn_fn = SpawnFn::new(move |supervisor: ActorCell, child_id: String| { +fn create_external_child_spec_with_args(args: ExternalSTTArgs) -> DynChildSpec { + let spawn_fn = DynSpawnFn::new(move |supervisor: ActorCell, child_id: String| { let args = args.clone(); async move { let (actor_ref, _handle) = @@ -92,10 +95,10 @@ fn create_external_child_spec_with_args(args: ExternalSTTArgs) -> ChildSpec { } }); - ChildSpec { + DynChildSpec { id: EXTERNAL_STT_ACTOR_NAME.to_string(), spawn_fn, - restart: Restart::Transient, + restart: RestartPolicy::Transient, backoff_fn: Some(ChildBackoffFn::new(|_, _, _, _| { Some(Duration::from_secs(1)) })), diff --git a/plugins/network/Cargo.toml b/plugins/network/Cargo.toml index 74201dfe0b..07108903c0 100644 --- a/plugins/network/Cargo.toml +++ b/plugins/network/Cargo.toml @@ -21,7 +21,6 @@ serde = { workspace = true } specta = { workspace = true } ractor = { workspace = true } -ractor-supervisor = { workspace = true } reqwest = { workspace = true } tracing = { workspace = true }
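
A minimal usage sketch (not part of the diff) of how a caller could wire up the new one-for-one `hypr_supervisor::Supervisor`, using only the types this patch introduces; `MyWorker` and `MyWorkerArgs` are hypothetical placeholders for a real child actor, and the defaults shown are the ones defined in `restart.rs` and `retry.rs`.

use hypr_supervisor::{
    ChildSpec, RestartBudget, RestartPolicy, RetryStrategy, SpawnFn, Supervisor, SupervisorConfig,
};
use ractor::Actor;

async fn spawn_worker_supervisor() -> Result<(), ractor::SpawnErr> {
    let spec = ChildSpec {
        id: "worker".to_string(),
        restart_policy: RestartPolicy::Permanent,
        // The spawn closure receives the supervisor's cell so the child is spawn-linked to it.
        spawn_fn: SpawnFn::new(|sup_cell| async move {
            let (worker_ref, _handle) = Actor::spawn_linked(
                Some("worker".to_string()),
                MyWorker,                // hypothetical child actor
                MyWorkerArgs::default(), // hypothetical arguments
                sup_cell,
            )
            .await?;
            Ok(worker_ref.get_cell())
        }),
    };

    let config = SupervisorConfig {
        children: vec![spec],
        restart_budget: RestartBudget::default(), // 3 restarts per 15s window, reset after 30s quiet
        retry_strategy: RetryStrategy::default(), // 3 spawn attempts, exponential backoff from 100ms
    };

    let (_sup_ref, _sup_handle) =
        Actor::spawn(Some("worker_supervisor".to_string()), Supervisor, config).await?;
    Ok(())
}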