Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions config/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,5 @@ llm_seconds = 120
embedding_seconds = 30
# A2A remote call timeout in seconds
a2a_seconds = 30
# Maximum number of tool calls to execute in parallel
max_parallel_tools = 8
1 change: 1 addition & 0 deletions crates/zeph-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ metal = ["zeph-llm/metal"]
[dependencies]
age.workspace = true
anyhow.workspace = true
futures.workspace = true
notify.workspace = true
notify-debouncer-mini.workspace = true
serde = { workspace = true, features = ["derive"] }
Expand Down
235 changes: 218 additions & 17 deletions crates/zeph-core/src/agent/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,25 +617,50 @@ impl<C: Channel, T: ToolExecutor> Agent<C, T> {
.await;
self.push_message(assistant_msg);

let mut result_parts: Vec<MessagePart> = Vec::new();
for tc in tool_calls {
let params: std::collections::HashMap<String, serde_json::Value> =
if let serde_json::Value::Object(map) = &tc.input {
map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
} else {
std::collections::HashMap::new()
};
// Build tool calls for all requests
let calls: Vec<ToolCall> = tool_calls
.iter()
.map(|tc| {
let params: std::collections::HashMap<String, serde_json::Value> =
if let serde_json::Value::Object(map) = &tc.input {
map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
} else {
std::collections::HashMap::new()
};
ToolCall {
tool_id: tc.name.clone(),
params,
}
})
.collect();

let call = ToolCall {
tool_id: tc.name.clone(),
params,
};
// Execute tool calls in parallel
let max_parallel = self.runtime.timeouts.max_parallel_tools;
let tool_results = if calls.len() <= max_parallel {
let futs: Vec<_> = calls
.iter()
.zip(tool_calls.iter())
.map(|(call, tc)| {
self.tool_executor.execute_tool_call(call).instrument(
tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
)
})
.collect();
futures::future::join_all(futs).await
} else {
use futures::StreamExt;
let stream =
futures::stream::iter(calls.iter().zip(tool_calls.iter()).map(|(call, tc)| {
self.tool_executor.execute_tool_call(call).instrument(
tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id),
)
}));
futures::StreamExt::collect::<Vec<_>>(stream.buffered(max_parallel)).await
};

let tool_result = self
.tool_executor
.execute_tool_call(&call)
.instrument(tracing::info_span!("tool_exec", tool_name = %tc.name))
.await;
// Process results sequentially (metrics, channel sends, message parts)
let mut result_parts: Vec<MessagePart> = Vec::new();
for (tc, tool_result) in tool_calls.iter().zip(tool_results) {
let (output, is_error, inline_stats) = match tool_result {
Ok(Some(out)) => {
if let Some(ref fs) = out.filter_stats {
Expand Down Expand Up @@ -735,3 +760,179 @@ fn tool_def_to_definition(def: &zeph_tools::registry::ToolDef) -> ToolDefinition
parameters: serde_json::to_value(&def.schema).unwrap_or_default(),
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};

use futures::future::join_all;
use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};

struct DelayExecutor {
delay: Duration,
call_order: Arc<AtomicUsize>,
}

impl ToolExecutor for DelayExecutor {
fn execute(
&self,
_response: &str,
) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
std::future::ready(Ok(None))
}

fn execute_tool_call(
&self,
call: &ToolCall,
) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
let delay = self.delay;
let order = self.call_order.clone();
let idx = order.fetch_add(1, Ordering::SeqCst);
let tool_id = call.tool_id.clone();
async move {
tokio::time::sleep(delay).await;
Ok(Some(ToolOutput {
tool_name: tool_id,
summary: format!("result-{idx}"),
blocks_executed: 1,
diff: None,
filter_stats: None,
}))
}
}
}

struct FailingNthExecutor {
fail_index: usize,
call_count: AtomicUsize,
}

impl ToolExecutor for FailingNthExecutor {
fn execute(
&self,
_response: &str,
) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
std::future::ready(Ok(None))
}

fn execute_tool_call(
&self,
call: &ToolCall,
) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
let fail = idx == self.fail_index;
let tool_id = call.tool_id.clone();
async move {
if fail {
Err(ToolError::Execution(std::io::Error::new(
std::io::ErrorKind::Other,
format!("tool {tool_id} failed"),
)))
} else {
Ok(Some(ToolOutput {
tool_name: tool_id,
summary: format!("ok-{idx}"),
blocks_executed: 1,
diff: None,
filter_stats: None,
}))
}
}
}
}

fn make_calls(n: usize) -> Vec<ToolCall> {
(0..n)
.map(|i| ToolCall {
tool_id: format!("tool-{i}"),
params: HashMap::new(),
})
.collect()
}

#[tokio::test]
async fn parallel_preserves_result_order() {
let executor = DelayExecutor {
delay: Duration::from_millis(10),
call_order: Arc::new(AtomicUsize::new(0)),
};
let calls = make_calls(5);

let futs: Vec<_> = calls
.iter()
.map(|c| executor.execute_tool_call(c))
.collect();
let results = join_all(futs).await;

for (i, r) in results.iter().enumerate() {
let out = r.as_ref().unwrap().as_ref().unwrap();
assert_eq!(out.tool_name, format!("tool-{i}"));
}
}

#[tokio::test]
async fn parallel_faster_than_sequential() {
let executor = DelayExecutor {
delay: Duration::from_millis(50),
call_order: Arc::new(AtomicUsize::new(0)),
};
let calls = make_calls(4);

let start = Instant::now();
let futs: Vec<_> = calls
.iter()
.map(|c| executor.execute_tool_call(c))
.collect();
let _results = join_all(futs).await;
let parallel_time = start.elapsed();

// Sequential would take >= 200ms (4 * 50ms); parallel should be ~50ms
assert!(
parallel_time < Duration::from_millis(150),
"parallel took {parallel_time:?}, expected < 150ms"
);
}

#[tokio::test]
async fn one_failure_does_not_block_others() {
let executor = FailingNthExecutor {
fail_index: 1,
call_count: AtomicUsize::new(0),
};
let calls = make_calls(3);

let futs: Vec<_> = calls
.iter()
.map(|c| executor.execute_tool_call(c))
.collect();
let results = join_all(futs).await;

assert!(results[0].is_ok());
assert!(results[1].is_err());
assert!(results[2].is_ok());
}

#[tokio::test]
async fn buffered_preserves_order() {
use futures::StreamExt;

let executor = DelayExecutor {
delay: Duration::from_millis(10),
call_order: Arc::new(AtomicUsize::new(0)),
};
let calls = make_calls(6);
let max_parallel = 2;

let stream = futures::stream::iter(calls.iter().map(|c| executor.execute_tool_call(c)));
let results: Vec<_> =
futures::StreamExt::collect::<Vec<_>>(stream.buffered(max_parallel)).await;

for (i, r) in results.iter().enumerate() {
let out = r.as_ref().unwrap().as_ref().unwrap();
assert_eq!(out.tool_name, format!("tool-{i}"));
}
}
}
7 changes: 7 additions & 0 deletions crates/zeph-core/src/config/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,10 @@ fn default_a2a_timeout() -> u64 {
30
}

fn default_max_parallel_tools() -> usize {
8
}

#[derive(Debug, Clone, Copy, Deserialize)]
pub struct SecurityConfig {
#[serde(default = "default_true")]
Expand All @@ -699,6 +703,8 @@ pub struct TimeoutConfig {
pub embedding_seconds: u64,
#[serde(default = "default_a2a_timeout")]
pub a2a_seconds: u64,
#[serde(default = "default_max_parallel_tools")]
pub max_parallel_tools: usize,
}

impl Default for TimeoutConfig {
Expand All @@ -707,6 +713,7 @@ impl Default for TimeoutConfig {
llm_seconds: default_llm_timeout(),
embedding_seconds: default_embedding_timeout(),
a2a_seconds: default_a2a_timeout(),
max_parallel_tools: default_max_parallel_tools(),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,7 @@ async fn agent_with_security_config() {
llm_seconds: 60,
embedding_seconds: 15,
a2a_seconds: 10,
max_parallel_tools: 8,
};

let mut agent = Agent::new(
Expand Down
Loading