diff --git a/Cargo.lock b/Cargo.lock index 55473ace..778c1fae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8299,6 +8299,7 @@ dependencies = [ "tempfile", "tokio", "tokio-stream", + "tokio-util", "tracing", "tracing-opentelemetry", "tracing-subscriber", @@ -8380,6 +8381,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-stream", + "tokio-util", "toml 1.0.2+spec-1.1.0", "tracing", "zeph-index", @@ -8551,6 +8553,7 @@ dependencies = [ "tempfile", "thiserror 2.0.18", "tokio", + "tokio-util", "toml 1.0.2+spec-1.1.0", "tracing", "url", diff --git a/Cargo.toml b/Cargo.toml index 285075fc..ff0b6199 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -122,7 +122,8 @@ mock = ["zeph-llm/mock"] [dependencies] anyhow.workspace = true -tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "sync"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "sync", "time"] } +tokio-util.workspace = true tracing.workspace = true tracing-subscriber.workspace = true opentelemetry = { workspace = true, optional = true } diff --git a/crates/zeph-core/Cargo.toml b/crates/zeph-core/Cargo.toml index eb524dbe..ce5e4348 100644 --- a/crates/zeph-core/Cargo.toml +++ b/crates/zeph-core/Cargo.toml @@ -24,6 +24,7 @@ serde = { workspace = true, features = ["derive"] } serde_json.workspace = true thiserror.workspace = true tokio = { workspace = true, features = ["fs", "macros", "rt-multi-thread", "sync", "time"] } +tokio-util.workspace = true tokio-stream.workspace = true toml.workspace = true tracing.workspace = true diff --git a/crates/zeph-core/src/agent/mod.rs b/crates/zeph-core/src/agent/mod.rs index dff9ee47..dee8b309 100644 --- a/crates/zeph-core/src/agent/mod.rs +++ b/crates/zeph-core/src/agent/mod.rs @@ -12,7 +12,10 @@ use std::collections::VecDeque; use std::path::PathBuf; use std::time::{Duration, Instant}; -use tokio::sync::{mpsc, watch}; +use std::sync::Arc; + +use tokio::sync::{Notify, mpsc, watch}; +use tokio_util::sync::CancellationToken; use zeph_llm::any::AnyProvider; use zeph_llm::provider::{LlmProvider, Message, Role}; @@ -123,6 +126,8 @@ pub struct Agent { pub(super) mcp: McpState, #[cfg(feature = "index")] pub(super) index: IndexState, + cancel_signal: Arc, + cancel_token: CancellationToken, start_time: Instant, message_queue: VecDeque, summary_provider: Option, @@ -216,6 +221,8 @@ impl Agent { cached_repo_map: None, repo_map_ttl: std::time::Duration::from_secs(300), }, + cancel_signal: Arc::new(Notify::new()), + cancel_token: CancellationToken::new(), start_time: Instant::now(), message_queue: VecDeque::new(), summary_provider: None, @@ -426,6 +433,14 @@ impl Agent { self } + /// Returns a handle that can cancel the current in-flight operation. + /// The returned `Notify` is stable across messages — callers invoke + /// `notify_waiters()` to cancel whatever operation is running. + #[must_use] + pub fn cancel_signal(&self) -> Arc { + Arc::clone(&self.cancel_signal) + } + fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) { if let Some(ref tx) = self.metrics_tx { let elapsed = self.start_time.elapsed().as_secs(); @@ -620,6 +635,13 @@ impl Agent { } async fn process_user_message(&mut self, text: String) -> Result<(), error::AgentError> { + self.cancel_token = CancellationToken::new(); + let signal = Arc::clone(&self.cancel_signal); + let token = self.cancel_token.clone(); + tokio::spawn(async move { + signal.notified().await; + token.cancel(); + }); let trimmed = text.trim(); if trimmed == "/skills" { @@ -1931,4 +1953,60 @@ pub(super) mod agent_tests { let recent = &history[history.len() - DOOM_LOOP_WINDOW..]; assert!(!recent.windows(2).all(|w| w[0] == w[1])); } + + #[tokio::test] + async fn cancel_signal_propagates_to_fresh_token() { + use tokio_util::sync::CancellationToken; + let signal = Arc::new(Notify::new()); + + let token = CancellationToken::new(); + let sig = Arc::clone(&signal); + let tok = token.clone(); + tokio::spawn(async move { + sig.notified().await; + tok.cancel(); + }); + + // Yield to let the spawned task reach notified().await + tokio::task::yield_now().await; + assert!(!token.is_cancelled()); + signal.notify_waiters(); + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + assert!(token.is_cancelled()); + } + + #[tokio::test] + async fn cancel_signal_works_across_multiple_messages() { + use tokio_util::sync::CancellationToken; + let signal = Arc::new(Notify::new()); + + // First "message" + let token1 = CancellationToken::new(); + let sig1 = Arc::clone(&signal); + let tok1 = token1.clone(); + tokio::spawn(async move { + sig1.notified().await; + tok1.cancel(); + }); + + tokio::task::yield_now().await; + signal.notify_waiters(); + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + assert!(token1.is_cancelled()); + + // Second "message" — same signal, new token + let token2 = CancellationToken::new(); + let sig2 = Arc::clone(&signal); + let tok2 = token2.clone(); + tokio::spawn(async move { + sig2.notified().await; + tok2.cancel(); + }); + + tokio::task::yield_now().await; + assert!(!token2.is_cancelled()); + signal.notify_waiters(); + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + assert!(token2.is_cancelled()); + } } diff --git a/crates/zeph-core/src/agent/streaming.rs b/crates/zeph-core/src/agent/streaming.rs index c9da1082..50b13c72 100644 --- a/crates/zeph-core/src/agent/streaming.rs +++ b/crates/zeph-core/src/agent/streaming.rs @@ -77,6 +77,11 @@ impl Agent { self.doom_loop_history.clear(); for iteration in 0..self.runtime.max_tool_iterations { + if self.cancel_token.is_cancelled() { + tracing::info!("tool loop cancelled by user"); + break; + } + self.channel.send_typing().await?; // Context budget check at 80% threshold @@ -169,6 +174,10 @@ impl Agent { pub(crate) async fn call_llm_with_timeout( &mut self, ) -> Result, super::error::AgentError> { + if self.cancel_token.is_cancelled() { + return Ok(None); + } + if let Some(ref tracker) = self.cost_tracker && let Err(e) = tracker.check_budget() { @@ -184,12 +193,18 @@ impl Agent { let llm_span = tracing::info_span!("llm_call", model = %self.runtime.model_name); if self.provider.supports_streaming() { - if let Ok(r) = tokio::time::timeout( - llm_timeout, - self.process_response_streaming().instrument(llm_span), - ) - .await - { + let cancel = self.cancel_token.clone(); + let streaming_fut = self.process_response_streaming().instrument(llm_span); + let result = tokio::select! { + r = tokio::time::timeout(llm_timeout, streaming_fut) => r, + () = cancel.cancelled() => { + tracing::info!("LLM call cancelled by user"); + self.update_metrics(|m| m.cancellations += 1); + self.channel.send("[Cancelled]").await?; + return Ok(None); + } + }; + if let Ok(r) = result { let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX); let completion_estimate_for_cost = r .as_ref() @@ -211,12 +226,18 @@ impl Agent { Ok(None) } } else { - match tokio::time::timeout( - llm_timeout, - self.provider.chat(&self.messages).instrument(llm_span), - ) - .await - { + let cancel = self.cancel_token.clone(); + let chat_fut = self.provider.chat(&self.messages).instrument(llm_span); + let result = tokio::select! { + r = tokio::time::timeout(llm_timeout, chat_fut) => r, + () = cancel.cancelled() => { + tracing::info!("LLM call cancelled by user"); + self.update_metrics(|m| m.cancellations += 1); + self.channel.send("[Cancelled]").await?; + return Ok(None); + } + }; + match result { Ok(Ok(resp)) => { let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX); let completion_estimate = u64::try_from(resp.len()).unwrap_or(0) / 4; @@ -417,6 +438,12 @@ impl Agent { } Ok(false) } + Err(ToolError::Cancelled) => { + tracing::info!("tool execution cancelled"); + self.update_metrics(|m| m.cancellations += 1); + self.channel.send("[Cancelled]").await?; + Ok(false) + } Err(ToolError::SandboxViolation { path }) => { tracing::warn!("sandbox violation: {path}"); self.channel @@ -458,6 +485,10 @@ impl Agent { tracing::info!("streaming interrupted by shutdown"); break; } + () = self.cancel_token.cancelled() => { + tracing::info!("streaming interrupted by cancellation"); + break; + } }; let chunk: String = chunk_result?; response.push_str(&chunk); @@ -510,6 +541,10 @@ impl Agent { tracing::info!("native tool loop interrupted by shutdown"); break; } + if self.cancel_token.is_cancelled() { + tracing::info!("native tool loop cancelled by user"); + break; + } self.channel.send_typing().await?; @@ -595,14 +630,22 @@ impl Agent { let start = std::time::Instant::now(); let llm_span = tracing::info_span!("llm_call", model = %self.runtime.model_name); - let result = if let Ok(result) = tokio::time::timeout( + let chat_fut = tokio::time::timeout( llm_timeout, self.provider .chat_with_tools(&self.messages, tool_defs) .instrument(llm_span), - ) - .await - { + ); + let timeout_result = tokio::select! { + r = chat_fut => r, + () = self.cancel_token.cancelled() => { + tracing::info!("chat_with_tools cancelled by user"); + self.update_metrics(|m| m.cancellations += 1); + self.channel.send("[Cancelled]").await?; + return Ok(None); + } + }; + let result = if let Ok(result) = timeout_result { result? } else { self.channel @@ -686,28 +729,39 @@ impl Agent { }) .collect(); - // Execute tool calls in parallel + // Execute tool calls in parallel, with cancellation let max_parallel = self.runtime.timeouts.max_parallel_tools; - let tool_results = if calls.len() <= max_parallel { - let futs: Vec<_> = calls - .iter() - .zip(tool_calls.iter()) - .map(|(call, tc)| { - self.tool_executor.execute_tool_call(call).instrument( - tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id), - ) - }) - .collect(); - futures::future::join_all(futs).await - } else { - use futures::StreamExt; - let stream = - futures::stream::iter(calls.iter().zip(tool_calls.iter()).map(|(call, tc)| { - self.tool_executor.execute_tool_call(call).instrument( - tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id), - ) - })); - futures::StreamExt::collect::>(stream.buffered(max_parallel)).await + let exec_fut = async { + if calls.len() <= max_parallel { + let futs: Vec<_> = calls + .iter() + .zip(tool_calls.iter()) + .map(|(call, tc)| { + self.tool_executor.execute_tool_call(call).instrument( + tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id), + ) + }) + .collect(); + futures::future::join_all(futs).await + } else { + use futures::StreamExt; + let stream = + futures::stream::iter(calls.iter().zip(tool_calls.iter()).map(|(call, tc)| { + self.tool_executor.execute_tool_call(call).instrument( + tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id), + ) + })); + futures::StreamExt::collect::>(stream.buffered(max_parallel)).await + } + }; + let tool_results = tokio::select! { + results = exec_fut => results, + () = self.cancel_token.cancelled() => { + tracing::info!("tool execution cancelled by user"); + self.update_metrics(|m| m.cancellations += 1); + self.channel.send("[Cancelled]").await?; + return Ok(()); + } }; // Process results sequentially (metrics, channel sends, message parts) diff --git a/crates/zeph-core/src/metrics.rs b/crates/zeph-core/src/metrics.rs index 52667bd1..6a33ed07 100644 --- a/crates/zeph-core/src/metrics.rs +++ b/crates/zeph-core/src/metrics.rs @@ -34,6 +34,7 @@ pub struct MetricsSnapshot { pub filter_confidence_full: u64, pub filter_confidence_partial: u64, pub filter_confidence_fallback: u64, + pub cancellations: u64, } pub struct MetricsCollector { @@ -143,4 +144,13 @@ mod tests { collector.update(|m| m.summaries_count += 1); assert_eq!(rx.borrow().summaries_count, 2); } + + #[test] + fn cancellations_counter_increments() { + let (collector, rx) = MetricsCollector::new(); + assert_eq!(rx.borrow().cancellations, 0); + collector.update(|m| m.cancellations += 1); + collector.update(|m| m.cancellations += 1); + assert_eq!(rx.borrow().cancellations, 2); + } } diff --git a/crates/zeph-tools/Cargo.toml b/crates/zeph-tools/Cargo.toml index 1514a5f3..74c0b111 100644 --- a/crates/zeph-tools/Cargo.toml +++ b/crates/zeph-tools/Cargo.toml @@ -18,6 +18,7 @@ serde = { workspace = true, features = ["derive"] } serde_json.workspace = true thiserror.workspace = true tokio = { workspace = true, features = ["fs", "io-util", "macros", "process", "rt", "sync", "time"] } +tokio-util.workspace = true tracing.workspace = true url.workspace = true zeph-skills.workspace = true diff --git a/crates/zeph-tools/src/executor.rs b/crates/zeph-tools/src/executor.rs index ba7f1d69..04441e69 100644 --- a/crates/zeph-tools/src/executor.rs +++ b/crates/zeph-tools/src/executor.rs @@ -143,6 +143,9 @@ pub enum ToolError { #[error("command timed out after {timeout_secs}s")] Timeout { timeout_secs: u64 }, + #[error("operation cancelled")] + Cancelled, + #[error("execution failed: {0}")] Execution(#[from] std::io::Error), } diff --git a/crates/zeph-tools/src/shell.rs b/crates/zeph-tools/src/shell.rs index 89c09fed..97cefa9b 100644 --- a/crates/zeph-tools/src/shell.rs +++ b/crates/zeph-tools/src/shell.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use std::time::{Duration, Instant}; use tokio::process::Command; +use tokio_util::sync::CancellationToken; use schemars::JsonSchema; @@ -39,6 +40,7 @@ pub struct ShellExecutor { tool_event_tx: Option, permission_policy: Option, output_filter_registry: Option, + cancel_token: Option, } impl ShellExecutor { @@ -84,6 +86,7 @@ impl ShellExecutor { tool_event_tx: None, permission_policy: None, output_filter_registry: None, + cancel_token: None, } } @@ -105,6 +108,12 @@ impl ShellExecutor { self } + #[must_use] + pub fn with_cancel_token(mut self, token: CancellationToken) -> Self { + self.cancel_token = Some(token); + self + } + #[must_use] pub fn with_output_filters(mut self, registry: OutputFilterRegistry) -> Self { self.output_filter_registry = Some(registry); @@ -191,8 +200,21 @@ impl ShellExecutor { } let start = Instant::now(); - let (out, exit_code) = - execute_bash(block, self.timeout, self.tool_event_tx.as_ref()).await; + let (out, exit_code) = execute_bash( + block, + self.timeout, + self.tool_event_tx.as_ref(), + self.cancel_token.as_ref(), + ) + .await; + if exit_code == 130 + && self + .cancel_token + .as_ref() + .is_some_and(CancellationToken::is_cancelled) + { + return Err(ToolError::Cancelled); + } #[allow(clippy::cast_possible_truncation)] let duration_ms = start.elapsed().as_millis() as u64; @@ -398,10 +420,25 @@ fn chrono_now() -> String { format!("{secs}") } +/// Kill a child process and its descendants. +/// On unix, sends SIGKILL to child processes via `pkill -KILL -P ` before +/// killing the parent, preventing zombie subprocesses. +async fn kill_process_tree(child: &mut tokio::process::Child) { + #[cfg(unix)] + if let Some(pid) = child.id() { + let _ = Command::new("pkill") + .args(["-KILL", "-P", &pid.to_string()]) + .status() + .await; + } + let _ = child.kill().await; +} + async fn execute_bash( code: &str, timeout: Duration, event_tx: Option<&ToolEventTx>, + cancel_token: Option<&CancellationToken>, ) -> (String, i32) { use std::process::Stdio; use tokio::io::{AsyncBufReadExt, BufReader}; @@ -465,9 +502,18 @@ async fn execute_bash( } } () = tokio::time::sleep_until(deadline) => { - let _ = child.kill().await; + kill_process_tree(&mut child).await; return (format!("[error] command timed out after {timeout_secs}s"), 1); } + () = async { + match cancel_token { + Some(t) => t.cancelled().await, + None => std::future::pending().await, + } + } => { + kill_process_tree(&mut child).await; + return ("[cancelled] operation aborted".to_string(), 130); + } } } @@ -541,7 +587,7 @@ mod tests { #[tokio::test] #[cfg(not(target_os = "windows"))] async fn execute_simple_command() { - let (result, code) = execute_bash("echo hello", Duration::from_secs(30), None).await; + let (result, code) = execute_bash("echo hello", Duration::from_secs(30), None, None).await; assert!(result.contains("hello")); assert_eq!(code, 0); } @@ -549,7 +595,7 @@ mod tests { #[tokio::test] #[cfg(not(target_os = "windows"))] async fn execute_stderr_output() { - let (result, _) = execute_bash("echo err >&2", Duration::from_secs(30), None).await; + let (result, _) = execute_bash("echo err >&2", Duration::from_secs(30), None, None).await; assert!(result.contains("[stderr]")); assert!(result.contains("err")); } @@ -557,8 +603,13 @@ mod tests { #[tokio::test] #[cfg(not(target_os = "windows"))] async fn execute_stdout_and_stderr_combined() { - let (result, _) = - execute_bash("echo out && echo err >&2", Duration::from_secs(30), None).await; + let (result, _) = execute_bash( + "echo out && echo err >&2", + Duration::from_secs(30), + None, + None, + ) + .await; assert!(result.contains("out")); assert!(result.contains("[stderr]")); assert!(result.contains("err")); @@ -568,7 +619,7 @@ mod tests { #[tokio::test] #[cfg(not(target_os = "windows"))] async fn execute_empty_output() { - let (result, code) = execute_bash("true", Duration::from_secs(30), None).await; + let (result, code) = execute_bash("true", Duration::from_secs(30), None, None).await; assert_eq!(result, "(no output)"); assert_eq!(code, 0); } @@ -1141,7 +1192,7 @@ mod tests { #[cfg(unix)] #[tokio::test] async fn execute_bash_error_handling() { - let (result, code) = execute_bash("false", Duration::from_secs(5), None).await; + let (result, code) = execute_bash("false", Duration::from_secs(5), None, None).await; assert_eq!(result, "(no output)"); assert_eq!(code, 1); } @@ -1149,8 +1200,13 @@ mod tests { #[cfg(unix)] #[tokio::test] async fn execute_bash_command_not_found() { - let (result, _) = - execute_bash("nonexistent-command-xyz", Duration::from_secs(5), None).await; + let (result, _) = execute_bash( + "nonexistent-command-xyz", + Duration::from_secs(5), + None, + None, + ) + .await; assert!(result.contains("[stderr]") || result.contains("[error]")); } @@ -1240,4 +1296,66 @@ mod tests { let req = obj["required"].as_array().unwrap(); assert!(req.iter().any(|v| v.as_str() == Some("command"))); } + + #[tokio::test] + #[cfg(not(target_os = "windows"))] + async fn cancel_token_kills_child_process() { + let token = CancellationToken::new(); + let token_clone = token.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(100)).await; + token_clone.cancel(); + }); + let (result, code) = + execute_bash("sleep 60", Duration::from_secs(30), None, Some(&token)).await; + assert_eq!(code, 130); + assert!(result.contains("[cancelled]")); + } + + #[tokio::test] + #[cfg(not(target_os = "windows"))] + async fn cancel_token_none_does_not_cancel() { + let (result, code) = execute_bash("echo ok", Duration::from_secs(5), None, None).await; + assert_eq!(code, 0); + assert!(result.contains("ok")); + } + + #[tokio::test] + #[cfg(not(target_os = "windows"))] + async fn cancel_kills_child_process_group() { + use std::path::Path; + let marker = format!("/tmp/zeph-pgkill-test-{}", std::process::id()); + let script = format!("bash -c 'sleep 30 && touch {marker}' & sleep 60"); + let token = CancellationToken::new(); + let token_clone = token.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + token_clone.cancel(); + }); + let (result, code) = + execute_bash(&script, Duration::from_secs(30), None, Some(&token)).await; + assert_eq!(code, 130); + assert!(result.contains("[cancelled]")); + // Wait briefly, then verify the subprocess did NOT create the marker file + tokio::time::sleep(Duration::from_millis(500)).await; + assert!( + !Path::new(&marker).exists(), + "subprocess should have been killed with process group" + ); + } + + #[tokio::test] + #[cfg(not(target_os = "windows"))] + async fn shell_executor_cancel_returns_cancelled_error() { + let token = CancellationToken::new(); + let token_clone = token.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(100)).await; + token_clone.cancel(); + }); + let executor = ShellExecutor::new(&default_config()).with_cancel_token(token); + let response = "```bash\nsleep 60\n```"; + let result = executor.execute(response).await; + assert!(matches!(result, Err(ToolError::Cancelled))); + } } diff --git a/crates/zeph-tui/src/app.rs b/crates/zeph-tui/src/app.rs index c6afc739..10677448 100644 --- a/crates/zeph-tui/src/app.rs +++ b/crates/zeph-tui/src/app.rs @@ -1,5 +1,7 @@ +use std::sync::Arc; + use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{Notify, mpsc, oneshot, watch}; use tracing::debug; use crate::event::{AgentEvent, AppEvent}; @@ -73,6 +75,7 @@ pub struct App { draft_input: String, queued_count: usize, hyperlinks: Vec, + cancel_signal: Option>, } impl App { @@ -107,6 +110,7 @@ impl App { draft_input: String::new(), queued_count: 0, hyperlinks: Vec::new(), + cancel_signal: None, } } @@ -157,6 +161,12 @@ impl App { } } + #[must_use] + pub fn with_cancel_signal(mut self, signal: Arc) -> Self { + self.cancel_signal = Some(signal); + self + } + #[must_use] pub fn with_metrics_rx(mut self, rx: watch::Receiver) -> Self { self.metrics_rx = Some(rx); @@ -233,6 +243,11 @@ impl App { self.queued_count } + #[must_use] + pub fn is_agent_busy(&self) -> bool { + self.status_label.is_some() || self.messages.last().is_some_and(|m| m.streaming) + } + #[must_use] pub fn has_running_tool(&self) -> bool { self.messages @@ -524,6 +539,11 @@ impl App { fn handle_normal_key(&mut self, key: KeyEvent) { match key.code { + KeyCode::Esc if self.is_agent_busy() => { + if let Some(ref signal) = self.cancel_signal { + signal.notify_waiters(); + } + } KeyCode::Char('q') => self.should_quit = true, KeyCode::Char('i') => self.input_mode = InputMode::Insert, KeyCode::Up | KeyCode::Char('k') => { @@ -1311,4 +1331,39 @@ mod tests { assert!(!app.show_help); assert_eq!(app.input(), "?"); } + + #[tokio::test] + async fn esc_in_normal_mode_cancels_when_busy() { + let (mut app, _rx, _tx) = make_app(); + let notify = Arc::new(Notify::new()); + let notify_waiter = Arc::clone(¬ify); + let handle = tokio::spawn(async move { + notify_waiter.notified().await; + true + }); + tokio::task::yield_now().await; + + app = app.with_cancel_signal(Arc::clone(¬ify)); + app.input_mode = InputMode::Normal; + app.status_label = Some("Thinking...".into()); + assert!(app.is_agent_busy()); + + let key = KeyEvent::new(KeyCode::Esc, KeyModifiers::NONE); + app.handle_event(AppEvent::Key(key)).unwrap(); + let result = tokio::time::timeout(std::time::Duration::from_millis(100), handle).await; + assert!(result.is_ok(), "notify should have been triggered"); + } + + #[test] + fn esc_in_normal_mode_does_not_cancel_when_idle() { + let (mut app, _rx, _tx) = make_app(); + let notify = Arc::new(Notify::new()); + app = app.with_cancel_signal(notify); + app.input_mode = InputMode::Normal; + assert!(!app.is_agent_busy()); + + let key = KeyEvent::new(KeyCode::Esc, KeyModifiers::NONE); + app.handle_event(AppEvent::Key(key)).unwrap(); + // No way to assert "not notified" directly, but we verify no panic + } } diff --git a/crates/zeph-tui/src/widgets/status.rs b/crates/zeph-tui/src/widgets/status.rs index a595f592..c4c56b0c 100644 --- a/crates/zeph-tui/src/widgets/status.rs +++ b/crates/zeph-tui/src/widgets/status.rs @@ -21,8 +21,14 @@ pub fn render(app: &App, metrics: &MetricsSnapshot, frame: &mut Frame, area: Rec let panel = if app.show_side_panels() { "ON" } else { "OFF" }; + let cancel_hint = if app.is_agent_busy() && app.input_mode() == InputMode::Normal { + " | [Esc to cancel]" + } else { + "" + }; + let text = format!( - " [{mode}] | Panel: {panel} | Skills: {active}/{total} | Tokens: {tok} | Qdrant: {qdrant} | API: {api} | {uptime}", + " [{mode}] | Panel: {panel} | Skills: {active}/{total} | Tokens: {tok} | Qdrant: {qdrant} | API: {api} | {uptime}{cancel_hint}", active = metrics.active_skills.len(), total = metrics.total_skills, tok = format_tokens(metrics.total_tokens), diff --git a/src/main.rs b/src/main.rs index 7874d51e..c8b513a1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -165,14 +165,6 @@ async fn main() -> anyhow::Result<()> { tracing::info!("conversation id: {conversation_id}"); let (shutdown_tx, shutdown_rx) = AppBuilder::build_shutdown(); - tokio::spawn(async move { - if let Err(e) = tokio::signal::ctrl_c().await { - tracing::error!("failed to listen for ctrl-c: {e:#}"); - return; - } - tracing::info!("received shutdown signal"); - let _ = shutdown_tx.send(true); - }); tokio::task::spawn_blocking(|| { zeph_tools::cleanup_overflow_files(std::time::Duration::from_secs(86_400)); @@ -406,6 +398,28 @@ async fn main() -> anyhow::Result<()> { let mut agent = agent; + // Double Ctrl+C: first cancels current operation, second within 2s shuts down + let cancel_signal = agent.cancel_signal(); + tokio::spawn(async move { + let mut last_ctrl_c: Option = None; + loop { + if tokio::signal::ctrl_c().await.is_err() { + break; + } + let now = tokio::time::Instant::now(); + if let Some(prev) = last_ctrl_c + && now.duration_since(prev) < std::time::Duration::from_secs(2) + { + tracing::info!("received second ctrl-c, shutting down"); + let _ = shutdown_tx.send(true); + break; + } + tracing::info!("received ctrl-c, cancelling current operation"); + cancel_signal.notify_waiters(); + last_ctrl_c = Some(now); + } + }); + agent.load_history().await?; #[cfg(feature = "tui")] @@ -415,7 +429,8 @@ async fn main() -> anyhow::Result<()> { let reader = EventReader::new(event_tx, Duration::from_millis(100)); std::thread::spawn(move || reader.run()); - let mut tui_app = App::new(tui_handle.user_tx, tui_handle.agent_rx); + let mut tui_app = App::new(tui_handle.user_tx, tui_handle.agent_rx) + .with_cancel_signal(agent.cancel_signal()); tui_app.set_show_source_labels(config.tui.show_source_labels); let history: Vec<(&str, &str)> = agent