diff --git a/Cargo.lock b/Cargo.lock
index 28f388cb..46294c3b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8369,6 +8369,7 @@ dependencies = [
  "age",
  "anyhow",
  "criterion",
+ "futures",
  "notify",
  "notify-debouncer-mini",
  "serde",
diff --git a/config/default.toml b/config/default.toml
index 17c31a71..72800cb6 100644
--- a/config/default.toml
+++ b/config/default.toml
@@ -318,3 +318,5 @@ llm_seconds = 120
 embedding_seconds = 30
 # A2A remote call timeout in seconds
 a2a_seconds = 30
+# Maximum number of tool calls to execute in parallel
+max_parallel_tools = 8
diff --git a/crates/zeph-core/Cargo.toml b/crates/zeph-core/Cargo.toml
index 2e65a7d3..5e450c70 100644
--- a/crates/zeph-core/Cargo.toml
+++ b/crates/zeph-core/Cargo.toml
@@ -17,6 +17,7 @@ metal = ["zeph-llm/metal"]
 [dependencies]
 age.workspace = true
 anyhow.workspace = true
+futures.workspace = true
 notify.workspace = true
 notify-debouncer-mini.workspace = true
 serde = { workspace = true, features = ["derive"] }
diff --git a/crates/zeph-core/src/agent/streaming.rs b/crates/zeph-core/src/agent/streaming.rs
index 2a77dd9c..7607652d 100644
--- a/crates/zeph-core/src/agent/streaming.rs
+++ b/crates/zeph-core/src/agent/streaming.rs
@@ -617,25 +617,50 @@
             .await;
         self.push_message(assistant_msg);
 
-        let mut result_parts: Vec<_> = Vec::new();
-        for tc in tool_calls {
-            let params: std::collections::HashMap<String, serde_json::Value> =
-                if let serde_json::Value::Object(map) = &tc.input {
-                    map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
-                } else {
-                    std::collections::HashMap::new()
-                };
+        // Build tool calls for all requests
+        let calls: Vec<ToolCall> = tool_calls
+            .iter()
+            .map(|tc| {
+                let params: std::collections::HashMap<String, serde_json::Value> =
+                    if let serde_json::Value::Object(map) = &tc.input {
+                        map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
+                    } else {
+                        std::collections::HashMap::new()
+                    };
+                ToolCall {
+                    tool_id: tc.name.clone(),
+                    params,
+                }
+            })
+            .collect();
 
-            let call = ToolCall {
-                tool_id: tc.name.clone(),
-                params,
-            };
+        // Execute tool calls in parallel
+        let max_parallel = self.runtime.timeouts.max_parallel_tools;
+        let tool_results = if calls.len() <= max_parallel {
+            let futs: Vec<_> = calls
+                .iter()
+                .zip(tool_calls.iter())
+                .map(|(call, tc)| {
+                    self.tool_executor.execute_tool_call(call).instrument(
+                        tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
+                    )
+                })
+                .collect();
+            futures::future::join_all(futs).await
+        } else {
+            use futures::StreamExt;
+            let stream =
+                futures::stream::iter(calls.iter().zip(tool_calls.iter()).map(|(call, tc)| {
+                    self.tool_executor.execute_tool_call(call).instrument(
+                        tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
+                    )
+                }));
+            futures::StreamExt::collect::<Vec<_>>(stream.buffered(max_parallel)).await
+        };
 
-            let tool_result = self
-                .tool_executor
-                .execute_tool_call(&call)
-                .instrument(tracing::info_span!("tool_exec", tool_name = %tc.name))
-                .await;
+        // Process results sequentially (metrics, channel sends, message parts)
+        let mut result_parts: Vec<_> = Vec::new();
+        for (tc, tool_result) in tool_calls.iter().zip(tool_results) {
             let (output, is_error, inline_stats) = match tool_result {
                 Ok(Some(out)) => {
                     if let Some(ref fs) = out.filter_stats {
@@ -735,3 +760,179 @@ fn tool_def_to_definition(def: &zeph_tools::registry::ToolDef) -> ToolDefinition
         parameters: serde_json::to_value(&def.schema).unwrap_or_default(),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+    use std::sync::Arc;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::time::{Duration, Instant};
+
+    use futures::future::join_all;
+    use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
+
+    struct DelayExecutor {
+        delay: Duration,
+        call_order: Arc<AtomicUsize>,
+    }
+
+    impl ToolExecutor for DelayExecutor {
+        fn execute(
+            &self,
+            _response: &str,
+        ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
+            std::future::ready(Ok(None))
+        }
+
+        fn execute_tool_call(
+            &self,
+            call: &ToolCall,
+        ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
+            let delay = self.delay;
+            let order = self.call_order.clone();
+            let idx = order.fetch_add(1, Ordering::SeqCst);
+            let tool_id = call.tool_id.clone();
+            async move {
+                tokio::time::sleep(delay).await;
+                Ok(Some(ToolOutput {
+                    tool_name: tool_id,
+                    summary: format!("result-{idx}"),
+                    blocks_executed: 1,
+                    diff: None,
+                    filter_stats: None,
+                }))
+            }
+        }
+    }
+
+    struct FailingNthExecutor {
+        fail_index: usize,
+        call_count: AtomicUsize,
+    }
+
+    impl ToolExecutor for FailingNthExecutor {
+        fn execute(
+            &self,
+            _response: &str,
+        ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
+            std::future::ready(Ok(None))
+        }
+
+        fn execute_tool_call(
+            &self,
+            call: &ToolCall,
+        ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
+            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
+            let fail = idx == self.fail_index;
+            let tool_id = call.tool_id.clone();
+            async move {
+                if fail {
+                    Err(ToolError::Execution(std::io::Error::new(
+                        std::io::ErrorKind::Other,
+                        format!("tool {tool_id} failed"),
+                    )))
+                } else {
+                    Ok(Some(ToolOutput {
+                        tool_name: tool_id,
+                        summary: format!("ok-{idx}"),
+                        blocks_executed: 1,
+                        diff: None,
+                        filter_stats: None,
+                    }))
+                }
+            }
+        }
+    }
+
+    fn make_calls(n: usize) -> Vec<ToolCall> {
+        (0..n)
+            .map(|i| ToolCall {
+                tool_id: format!("tool-{i}"),
+                params: HashMap::new(),
+            })
+            .collect()
+    }
+
+    #[tokio::test]
+    async fn parallel_preserves_result_order() {
+        let executor = DelayExecutor {
+            delay: Duration::from_millis(10),
+            call_order: Arc::new(AtomicUsize::new(0)),
+        };
+        let calls = make_calls(5);
+
+        let futs: Vec<_> = calls
+            .iter()
+            .map(|c| executor.execute_tool_call(c))
+            .collect();
+        let results = join_all(futs).await;
+
+        for (i, r) in results.iter().enumerate() {
+            let out = r.as_ref().unwrap().as_ref().unwrap();
+            assert_eq!(out.tool_name, format!("tool-{i}"));
+        }
+    }
+
+    #[tokio::test]
+    async fn parallel_faster_than_sequential() {
+        let executor = DelayExecutor {
+            delay: Duration::from_millis(50),
+            call_order: Arc::new(AtomicUsize::new(0)),
+        };
+        let calls = make_calls(4);
+
+        let start = Instant::now();
+        let futs: Vec<_> = calls
+            .iter()
+            .map(|c| executor.execute_tool_call(c))
+            .collect();
+        let _results = join_all(futs).await;
+        let parallel_time = start.elapsed();
+
+        // Sequential would take >= 200ms (4 * 50ms); parallel should be ~50ms
+        assert!(
+            parallel_time < Duration::from_millis(150),
+            "parallel took {parallel_time:?}, expected < 150ms"
+        );
+    }
+
+    #[tokio::test]
+    async fn one_failure_does_not_block_others() {
+        let executor = FailingNthExecutor {
+            fail_index: 1,
+            call_count: AtomicUsize::new(0),
+        };
+        let calls = make_calls(3);
+
+        let futs: Vec<_> = calls
+            .iter()
+            .map(|c| executor.execute_tool_call(c))
+            .collect();
+        let results = join_all(futs).await;
+
+        assert!(results[0].is_ok());
+        assert!(results[1].is_err());
+        assert!(results[2].is_ok());
+    }
+
+    #[tokio::test]
+    async fn buffered_preserves_order() {
+        use futures::StreamExt;
+
+        let executor = DelayExecutor {
+            delay: Duration::from_millis(10),
+            call_order: Arc::new(AtomicUsize::new(0)),
+        };
+        let calls = make_calls(6);
+        let max_parallel = 2;
+
+        let stream = futures::stream::iter(calls.iter().map(|c| executor.execute_tool_call(c)));
+        let results: Vec<_> =
+            futures::StreamExt::collect::<Vec<_>>(stream.buffered(max_parallel)).await;
+
+        for (i, r) in results.iter().enumerate() {
+            let out = r.as_ref().unwrap().as_ref().unwrap();
+            assert_eq!(out.tool_name, format!("tool-{i}"));
+        }
+    }
+}
diff --git a/crates/zeph-core/src/config/types.rs b/crates/zeph-core/src/config/types.rs
index eb75ff63..ee40f905 100644
--- a/crates/zeph-core/src/config/types.rs
+++ b/crates/zeph-core/src/config/types.rs
@@ -674,6 +674,10 @@ fn default_a2a_timeout() -> u64 {
     30
 }
 
+fn default_max_parallel_tools() -> usize {
+    8
+}
+
 #[derive(Debug, Clone, Copy, Deserialize)]
 pub struct SecurityConfig {
     #[serde(default = "default_true")]
@@ -699,6 +703,8 @@ pub struct TimeoutConfig {
     pub embedding_seconds: u64,
     #[serde(default = "default_a2a_timeout")]
     pub a2a_seconds: u64,
+    #[serde(default = "default_max_parallel_tools")]
+    pub max_parallel_tools: usize,
 }
 
 impl Default for TimeoutConfig {
@@ -707,6 +713,7 @@ fn default() -> Self {
             llm_seconds: default_llm_timeout(),
             embedding_seconds: default_embedding_timeout(),
             a2a_seconds: default_a2a_timeout(),
+            max_parallel_tools: default_max_parallel_tools(),
         }
     }
 }
diff --git a/tests/integration.rs b/tests/integration.rs
index 121b167e..9fec248b 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -1269,6 +1269,7 @@ async fn agent_with_security_config() {
         llm_seconds: 60,
         embedding_seconds: 15,
         a2a_seconds: 10,
+        max_parallel_tools: 8,
     };
 
     let mut agent = Agent::new(