Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ mock = ["zeph-llm/mock"]

[dependencies]
anyhow.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "sync"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "sync", "time"] }
tokio-util.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
opentelemetry = { workspace = true, optional = true }
Expand Down
1 change: 1 addition & 0 deletions crates/zeph-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["fs", "macros", "rt-multi-thread", "sync", "time"] }
tokio-util.workspace = true
tokio-stream.workspace = true
toml.workspace = true
tracing.workspace = true
Expand Down
80 changes: 79 additions & 1 deletion crates/zeph-core/src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use std::collections::VecDeque;
use std::path::PathBuf;
use std::time::{Duration, Instant};

use tokio::sync::{mpsc, watch};
use std::sync::Arc;

use tokio::sync::{Notify, mpsc, watch};
use tokio_util::sync::CancellationToken;
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::{LlmProvider, Message, Role};

Expand Down Expand Up @@ -123,6 +126,8 @@ pub struct Agent<C: Channel, T: ToolExecutor> {
pub(super) mcp: McpState,
#[cfg(feature = "index")]
pub(super) index: IndexState,
cancel_signal: Arc<Notify>,
cancel_token: CancellationToken,
start_time: Instant,
message_queue: VecDeque<QueuedMessage>,
summary_provider: Option<AnyProvider>,
Expand Down Expand Up @@ -216,6 +221,8 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
cached_repo_map: None,
repo_map_ttl: std::time::Duration::from_secs(300),
},
cancel_signal: Arc::new(Notify::new()),
cancel_token: CancellationToken::new(),
start_time: Instant::now(),
message_queue: VecDeque::new(),
summary_provider: None,
Expand Down Expand Up @@ -426,6 +433,14 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
self
}

/// Returns a handle that can cancel the current in-flight operation.
/// The returned `Notify` is stable across messages — callers invoke
/// `notify_waiters()` to cancel whatever operation is running.
#[must_use]
pub fn cancel_signal(&self) -> Arc<Notify> {
Arc::clone(&self.cancel_signal)
}

fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) {
if let Some(ref tx) = self.metrics_tx {
let elapsed = self.start_time.elapsed().as_secs();
Expand Down Expand Up @@ -620,6 +635,13 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
}

async fn process_user_message(&mut self, text: String) -> Result<(), error::AgentError> {
self.cancel_token = CancellationToken::new();
let signal = Arc::clone(&self.cancel_signal);
let token = self.cancel_token.clone();
tokio::spawn(async move {
signal.notified().await;
token.cancel();
});
let trimmed = text.trim();

if trimmed == "/skills" {
Expand Down Expand Up @@ -1931,4 +1953,60 @@ pub(super) mod agent_tests {
let recent = &history[history.len() - DOOM_LOOP_WINDOW..];
assert!(!recent.windows(2).all(|w| w[0] == w[1]));
}

#[tokio::test]
async fn cancel_signal_propagates_to_fresh_token() {
use tokio_util::sync::CancellationToken;
let signal = Arc::new(Notify::new());

let token = CancellationToken::new();
let sig = Arc::clone(&signal);
let tok = token.clone();
tokio::spawn(async move {
sig.notified().await;
tok.cancel();
});

// Yield to let the spawned task reach notified().await
tokio::task::yield_now().await;
assert!(!token.is_cancelled());
signal.notify_waiters();
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
assert!(token.is_cancelled());
}

#[tokio::test]
async fn cancel_signal_works_across_multiple_messages() {
use tokio_util::sync::CancellationToken;
let signal = Arc::new(Notify::new());

// First "message"
let token1 = CancellationToken::new();
let sig1 = Arc::clone(&signal);
let tok1 = token1.clone();
tokio::spawn(async move {
sig1.notified().await;
tok1.cancel();
});

tokio::task::yield_now().await;
signal.notify_waiters();
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
assert!(token1.is_cancelled());

// Second "message" — same signal, new token
let token2 = CancellationToken::new();
let sig2 = Arc::clone(&signal);
let tok2 = token2.clone();
tokio::spawn(async move {
sig2.notified().await;
tok2.cancel();
});

tokio::task::yield_now().await;
assert!(!token2.is_cancelled());
signal.notify_waiters();
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
assert!(token2.is_cancelled());
}
}
128 changes: 91 additions & 37 deletions crates/zeph-core/src/agent/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
self.doom_loop_history.clear();

for iteration in 0..self.runtime.max_tool_iterations {
if self.cancel_token.is_cancelled() {
tracing::info!("tool loop cancelled by user");
break;
}

self.channel.send_typing().await?;

// Context budget check at 80% threshold
Expand Down Expand Up @@ -169,6 +174,10 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
pub(crate) async fn call_llm_with_timeout(
&mut self,
) -> Result<Option<String>, super::error::AgentError> {
if self.cancel_token.is_cancelled() {
return Ok(None);
}

if let Some(ref tracker) = self.cost_tracker
&& let Err(e) = tracker.check_budget()
{
Expand All @@ -184,12 +193,18 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {

let llm_span = tracing::info_span!("llm_call", model = %self.runtime.model_name);
if self.provider.supports_streaming() {
if let Ok(r) = tokio::time::timeout(
llm_timeout,
self.process_response_streaming().instrument(llm_span),
)
.await
{
let cancel = self.cancel_token.clone();
let streaming_fut = self.process_response_streaming().instrument(llm_span);
let result = tokio::select! {
r = tokio::time::timeout(llm_timeout, streaming_fut) => r,
() = cancel.cancelled() => {
tracing::info!("LLM call cancelled by user");
self.update_metrics(|m| m.cancellations += 1);
self.channel.send("[Cancelled]").await?;
return Ok(None);
}
};
if let Ok(r) = result {
let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
let completion_estimate_for_cost = r
.as_ref()
Expand All @@ -211,12 +226,18 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
Ok(None)
}
} else {
match tokio::time::timeout(
llm_timeout,
self.provider.chat(&self.messages).instrument(llm_span),
)
.await
{
let cancel = self.cancel_token.clone();
let chat_fut = self.provider.chat(&self.messages).instrument(llm_span);
let result = tokio::select! {
r = tokio::time::timeout(llm_timeout, chat_fut) => r,
() = cancel.cancelled() => {
tracing::info!("LLM call cancelled by user");
self.update_metrics(|m| m.cancellations += 1);
self.channel.send("[Cancelled]").await?;
return Ok(None);
}
};
match result {
Ok(Ok(resp)) => {
let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
let completion_estimate = u64::try_from(resp.len()).unwrap_or(0) / 4;
Expand Down Expand Up @@ -417,6 +438,12 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
}
Ok(false)
}
Err(ToolError::Cancelled) => {
tracing::info!("tool execution cancelled");
self.update_metrics(|m| m.cancellations += 1);
self.channel.send("[Cancelled]").await?;
Ok(false)
}
Err(ToolError::SandboxViolation { path }) => {
tracing::warn!("sandbox violation: {path}");
self.channel
Expand Down Expand Up @@ -458,6 +485,10 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
tracing::info!("streaming interrupted by shutdown");
break;
}
() = self.cancel_token.cancelled() => {
tracing::info!("streaming interrupted by cancellation");
break;
}
};
let chunk: String = chunk_result?;
response.push_str(&chunk);
Expand Down Expand Up @@ -510,6 +541,10 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
tracing::info!("native tool loop interrupted by shutdown");
break;
}
if self.cancel_token.is_cancelled() {
tracing::info!("native tool loop cancelled by user");
break;
}

self.channel.send_typing().await?;

Expand Down Expand Up @@ -595,14 +630,22 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
let start = std::time::Instant::now();

let llm_span = tracing::info_span!("llm_call", model = %self.runtime.model_name);
let result = if let Ok(result) = tokio::time::timeout(
let chat_fut = tokio::time::timeout(
llm_timeout,
self.provider
.chat_with_tools(&self.messages, tool_defs)
.instrument(llm_span),
)
.await
{
);
let timeout_result = tokio::select! {
r = chat_fut => r,
() = self.cancel_token.cancelled() => {
tracing::info!("chat_with_tools cancelled by user");
self.update_metrics(|m| m.cancellations += 1);
self.channel.send("[Cancelled]").await?;
return Ok(None);
}
};
let result = if let Ok(result) = timeout_result {
result?
} else {
self.channel
Expand Down Expand Up @@ -686,28 +729,39 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
})
.collect();

// Execute tool calls in parallel
// Execute tool calls in parallel, with cancellation
let max_parallel = self.runtime.timeouts.max_parallel_tools;
let tool_results = if calls.len() <= max_parallel {
let futs: Vec<_> = calls
.iter()
.zip(tool_calls.iter())
.map(|(call, tc)| {
self.tool_executor.execute_tool_call(call).instrument(
tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
)
})
.collect();
futures::future::join_all(futs).await
} else {
use futures::StreamExt;
let stream =
futures::stream::iter(calls.iter().zip(tool_calls.iter()).map(|(call, tc)| {
self.tool_executor.execute_tool_call(call).instrument(
tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
)
}));
futures::StreamExt::collect::<Vec<_>>(stream.buffered(max_parallel)).await
let exec_fut = async {
if calls.len() <= max_parallel {
let futs: Vec<_> = calls
.iter()
.zip(tool_calls.iter())
.map(|(call, tc)| {
self.tool_executor.execute_tool_call(call).instrument(
tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
)
})
.collect();
futures::future::join_all(futs).await
} else {
use futures::StreamExt;
let stream =
futures::stream::iter(calls.iter().zip(tool_calls.iter()).map(|(call, tc)| {
self.tool_executor.execute_tool_call(call).instrument(
tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
)
}));
futures::StreamExt::collect::<Vec<_>>(stream.buffered(max_parallel)).await
}
};
let tool_results = tokio::select! {
results = exec_fut => results,
() = self.cancel_token.cancelled() => {
tracing::info!("tool execution cancelled by user");
self.update_metrics(|m| m.cancellations += 1);
self.channel.send("[Cancelled]").await?;
return Ok(());
}
};

// Process results sequentially (metrics, channel sends, message parts)
Expand Down
10 changes: 10 additions & 0 deletions crates/zeph-core/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub struct MetricsSnapshot {
pub filter_confidence_full: u64,
pub filter_confidence_partial: u64,
pub filter_confidence_fallback: u64,
pub cancellations: u64,
}

pub struct MetricsCollector {
Expand Down Expand Up @@ -143,4 +144,13 @@ mod tests {
collector.update(|m| m.summaries_count += 1);
assert_eq!(rx.borrow().summaries_count, 2);
}

#[test]
fn cancellations_counter_increments() {
let (collector, rx) = MetricsCollector::new();
assert_eq!(rx.borrow().cancellations, 0);
collector.update(|m| m.cancellations += 1);
collector.update(|m| m.cancellations += 1);
assert_eq!(rx.borrow().cancellations, 2);
}
}
1 change: 1 addition & 0 deletions crates/zeph-tools/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["fs", "io-util", "macros", "process", "rt", "sync", "time"] }
tokio-util.workspace = true
tracing.workspace = true
url.workspace = true
zeph-skills.workspace = true
Expand Down
3 changes: 3 additions & 0 deletions crates/zeph-tools/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ pub enum ToolError {
#[error("command timed out after {timeout_secs}s")]
Timeout { timeout_secs: u64 },

#[error("operation cancelled")]
Cancelled,

#[error("execution failed: {0}")]
Execution(#[from] std::io::Error),
}
Expand Down
Loading
Loading