diff --git a/crates/zeph-tools/src/executor.rs b/crates/zeph-tools/src/executor.rs
index 04441e6..00095e3 100644
--- a/crates/zeph-tools/src/executor.rs
+++ b/crates/zeph-tools/src/executor.rs
@@ -146,10 +146,28 @@ pub enum ToolError {
     #[error("operation cancelled")]
     Cancelled,
 
+    #[error("invalid tool parameters: {message}")]
+    InvalidParams { message: String },
+
     #[error("execution failed: {0}")]
     Execution(#[from] std::io::Error),
 }
 
+/// Deserialize tool call params from a `HashMap` into a typed struct.
+///
+/// # Errors
+///
+/// Returns `ToolError::InvalidParams` when deserialization fails.
+pub fn deserialize_params<T: serde::de::DeserializeOwned>(
+    params: &HashMap<String, serde_json::Value>,
+) -> Result<T, ToolError> {
+    let obj =
+        serde_json::Value::Object(params.iter().map(|(k, v)| (k.clone(), v.clone())).collect());
+    serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
+        message: e.to_string(),
+    })
+}
+
 /// Async trait for tool execution backends (shell, future MCP, A2A).
 ///
 /// Accepts the full LLM response and returns an optional output.
@@ -256,6 +274,91 @@ mod tests {
         assert_eq!(err.to_string(), "command timed out after 30s");
     }
 
+    #[test]
+    fn tool_error_invalid_params_display() {
+        let err = ToolError::InvalidParams {
+            message: "missing field `command`".to_owned(),
+        };
+        assert_eq!(
+            err.to_string(),
+            "invalid tool parameters: missing field `command`"
+        );
+    }
+
+    #[test]
+    fn deserialize_params_valid() {
+        #[derive(Debug, serde::Deserialize, PartialEq)]
+        struct P {
+            name: String,
+            count: u32,
+        }
+        let mut map = HashMap::new();
+        map.insert("name".to_owned(), serde_json::json!("test"));
+        map.insert("count".to_owned(), serde_json::json!(42));
+        let p: P = deserialize_params(&map).unwrap();
+        assert_eq!(
+            p,
+            P {
+                name: "test".to_owned(),
+                count: 42
+            }
+        );
+    }
+
+    #[test]
+    fn deserialize_params_missing_required_field() {
+        #[derive(Debug, serde::Deserialize)]
+        struct P {
+            #[allow(dead_code)]
+            name: String,
+        }
+        let map: HashMap<String, serde_json::Value> = HashMap::new();
+        let err = deserialize_params::<P>(&map).unwrap_err();
+        assert!(matches!(err, ToolError::InvalidParams { .. }));
+    }
+
+    #[test]
+    fn deserialize_params_wrong_type() {
+        #[derive(Debug, serde::Deserialize)]
+        struct P {
+            #[allow(dead_code)]
+            count: u32,
+        }
+        let mut map = HashMap::new();
+        map.insert("count".to_owned(), serde_json::json!("not a number"));
+        let err = deserialize_params::<P>(&map).unwrap_err();
+        assert!(matches!(err, ToolError::InvalidParams { .. }));
+    }
+
+    #[test]
+    fn deserialize_params_all_optional_empty() {
+        #[derive(Debug, serde::Deserialize, PartialEq)]
+        struct P {
+            name: Option<String>,
+        }
+        let map: HashMap<String, serde_json::Value> = HashMap::new();
+        let p: P = deserialize_params(&map).unwrap();
+        assert_eq!(p, P { name: None });
+    }
+
+    #[test]
+    fn deserialize_params_ignores_extra_fields() {
+        #[derive(Debug, serde::Deserialize, PartialEq)]
+        struct P {
+            name: String,
+        }
+        let mut map = HashMap::new();
+        map.insert("name".to_owned(), serde_json::json!("test"));
+        map.insert("extra".to_owned(), serde_json::json!(true));
+        let p: P = deserialize_params(&map).unwrap();
+        assert_eq!(
+            p,
+            P {
+                name: "test".to_owned()
+            }
+        );
+    }
+
     #[test]
     fn tool_error_execution_display() {
         let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
diff --git a/crates/zeph-tools/src/file.rs b/crates/zeph-tools/src/file.rs
index 0e49224..cda61d0 100644
--- a/crates/zeph-tools/src/file.rs
+++ b/crates/zeph-tools/src/file.rs
@@ -2,13 +2,14 @@ use std::collections::HashMap;
 use std::path::{Path, PathBuf};
 
 use schemars::JsonSchema;
+use serde::Deserialize;
 
-use crate::executor::{DiffData, ToolCall, ToolError, ToolExecutor, ToolOutput};
+use crate::executor::{
+    DiffData, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
+};
 use crate::registry::{InvocationHint, ToolDef};
 
-// Schema-only: fields are read by schemars derive, not by Rust code directly.
-#[derive(JsonSchema)]
-#[allow(dead_code)]
+#[derive(Deserialize, JsonSchema)]
 pub(crate) struct ReadParams {
     /// File path
     path: String,
@@ -18,9 +19,7 @@ pub(crate) struct ReadParams {
     limit: Option<u32>,
 }
 
-// Schema-only: fields are read by schemars derive, not by Rust code directly.
-#[derive(JsonSchema)]
-#[allow(dead_code)]
+#[derive(Deserialize, JsonSchema)]
 struct WriteParams {
     /// File path
     path: String,
@@ -28,9 +27,7 @@ struct WriteParams {
     content: String,
 }
 
-// Schema-only: fields are read by schemars derive, not by Rust code directly.
-#[derive(JsonSchema)]
-#[allow(dead_code)]
+#[derive(Deserialize, JsonSchema)]
 struct EditParams {
     /// File path
     path: String,
@@ -40,17 +37,13 @@ struct EditParams {
     new_string: String,
 }
 
-// Schema-only: fields are read by schemars derive, not by Rust code directly.
-#[derive(JsonSchema)]
-#[allow(dead_code)]
+#[derive(Deserialize, JsonSchema)]
 struct GlobParams {
     /// Glob pattern
     pattern: String,
 }
 
-// Schema-only: fields are read by schemars derive, not by Rust code directly.
-#[derive(JsonSchema)]
-#[allow(dead_code)]
+#[derive(Deserialize, JsonSchema)]
 struct GrepParams {
     /// Regex pattern
     pattern: String,
@@ -110,26 +103,36 @@ impl FileExecutor {
         params: &HashMap<String, serde_json::Value>,
     ) -> Result<Option<ToolOutput>, ToolError> {
         match tool_id {
-            "read" => self.handle_read(params),
-            "write" => self.handle_write(params),
-            "edit" => self.handle_edit(params),
-            "glob" => self.handle_glob(params),
-            "grep" => self.handle_grep(params),
+            "read" => {
+                let p: ReadParams = deserialize_params(params)?;
+                self.handle_read(&p)
+            }
+            "write" => {
+                let p: WriteParams = deserialize_params(params)?;
+                self.handle_write(&p)
+            }
+            "edit" => {
+                let p: EditParams = deserialize_params(params)?;
+                self.handle_edit(&p)
+            }
+            "glob" => {
+                let p: GlobParams = deserialize_params(params)?;
+                self.handle_glob(&p)
+            }
+            "grep" => {
+                let p: GrepParams = deserialize_params(params)?;
+                self.handle_grep(&p)
+            }
             _ => Ok(None),
         }
     }
 
-    fn handle_read(
-        &self,
-        params: &HashMap<String, serde_json::Value>,
-    ) -> Result<Option<ToolOutput>, ToolError> {
-        let path_str = param_str(params, "path")?;
-        let path = self.validate_path(Path::new(&path_str))?;
-
+    fn handle_read(&self, params: &ReadParams) -> Result<Option<ToolOutput>, ToolError> {
+        let path = self.validate_path(Path::new(&params.path))?;
         let content = std::fs::read_to_string(&path)?;
 
-        let offset = param_usize(params, "offset").unwrap_or(0);
-        let limit = param_usize(params, "limit").unwrap_or(usize::MAX);
+        let offset = params.offset.unwrap_or(0) as usize;
+        let limit = params.limit.map_or(usize::MAX, |l| l as usize);
 
         let selected: Vec<String> = content
             .lines()
@@ -149,62 +152,50 @@ impl FileExecutor {
         }))
     }
 
-    fn handle_write(
-        &self,
-        params: &HashMap<String, serde_json::Value>,
-    ) -> Result<Option<ToolOutput>, ToolError> {
-        let path_str = param_str(params, "path")?;
-        let content = param_str(params, "content")?;
-        let path = self.validate_path(Path::new(&path_str))?;
-
+    fn handle_write(&self, params: &WriteParams) -> Result<Option<ToolOutput>, ToolError> {
+        let path = self.validate_path(Path::new(&params.path))?;
         let old_content = std::fs::read_to_string(&path).unwrap_or_default();
 
         if let Some(parent) = path.parent() {
             std::fs::create_dir_all(parent)?;
         }
-        std::fs::write(&path, &content)?;
+        std::fs::write(&path, &params.content)?;
 
         Ok(Some(ToolOutput {
             tool_name: "write".to_owned(),
-            summary: format!("Wrote {} bytes to {path_str}", content.len()),
+            summary: format!("Wrote {} bytes to {}", params.content.len(), params.path),
             blocks_executed: 1,
             filter_stats: None,
             diff: Some(DiffData {
-                file_path: path_str,
+                file_path: params.path.clone(),
                 old_content,
-                new_content: content,
+                new_content: params.content.clone(),
             }),
             streamed: false,
         }))
     }
 
-    fn handle_edit(
-        &self,
-        params: &HashMap<String, serde_json::Value>,
-    ) -> Result<Option<ToolOutput>, ToolError> {
-        let path_str = param_str(params, "path")?;
-        let old_string = param_str(params, "old_string")?;
-        let new_string = param_str(params, "new_string")?;
-        let path = self.validate_path(Path::new(&path_str))?;
-
+    fn handle_edit(&self, params: &EditParams) -> Result<Option<ToolOutput>, ToolError> {
+        let path = self.validate_path(Path::new(&params.path))?;
         let content = std::fs::read_to_string(&path)?;
-        if !content.contains(&old_string) {
+
+        if !content.contains(&params.old_string) {
             return Err(ToolError::Execution(std::io::Error::new(
                 std::io::ErrorKind::NotFound,
-                format!("old_string not found in {path_str}"),
+                format!("old_string not found in {}", params.path),
             )));
         }
 
-        let new_content = content.replacen(&old_string, &new_string, 1);
+        let new_content = content.replacen(&params.old_string, &params.new_string, 1);
         std::fs::write(&path, &new_content)?;
 
         Ok(Some(ToolOutput {
             tool_name: "edit".to_owned(),
-            summary: format!("Edited {path_str}"),
+            summary: format!("Edited {}", params.path),
             blocks_executed: 1,
             filter_stats: None,
             diff: Some(DiffData {
-                file_path: path_str,
+                file_path: params.path.clone(),
                 old_content: content,
                 new_content,
             }),
@@ -212,12 +203,8 @@ impl FileExecutor {
         }))
     }
 
-    fn handle_glob(
-        &self,
-        params: &HashMap<String, serde_json::Value>,
-    ) -> Result<Option<ToolOutput>, ToolError> {
-        let pattern = param_str(params, "pattern")?;
-        let matches: Vec<String> = glob::glob(&pattern)
+    fn handle_glob(&self, params: &GlobParams) -> Result<Option<ToolOutput>, ToolError> {
+        let matches: Vec<String> = glob::glob(&params.pattern)
             .map_err(|e| {
                 ToolError::Execution(std::io::Error::new(
                     std::io::ErrorKind::InvalidInput,
@@ -235,7 +222,7 @@ impl FileExecutor {
         Ok(Some(ToolOutput {
             tool_name: "glob".to_owned(),
             summary: if matches.is_empty() {
-                format!("No files matching: {pattern}")
+                format!("No files matching: {}", params.pattern)
             } else {
                 matches.join("\n")
             },
@@ -246,23 +233,15 @@ impl FileExecutor {
         }))
     }
 
-    fn handle_grep(
-        &self,
-        params: &HashMap<String, serde_json::Value>,
-    ) -> Result<Option<ToolOutput>, ToolError> {
-        let pattern = param_str(params, "pattern")?;
-        let search_path = params.get("path").and_then(|v| v.as_str()).unwrap_or(".");
-        let case_sensitive = params
-            .get("case_sensitive")
-            .and_then(serde_json::Value::as_bool)
-            .unwrap_or(true);
-
+    fn handle_grep(&self, params: &GrepParams) -> Result<Option<ToolOutput>, ToolError> {
+        let search_path = params.path.as_deref().unwrap_or(".");
+        let case_sensitive = params.case_sensitive.unwrap_or(true);
         let path = self.validate_path(Path::new(search_path))?;
 
         let regex = if case_sensitive {
-            regex::Regex::new(&pattern)
+            regex::Regex::new(&params.pattern)
         } else {
-            regex::RegexBuilder::new(&pattern)
+            regex::RegexBuilder::new(&params.pattern)
                 .case_insensitive(true)
                 .build()
         }
@@ -279,7 +258,7 @@ impl FileExecutor {
         Ok(Some(ToolOutput {
             tool_name: "grep".to_owned(),
             summary: if results.is_empty() {
-                format!("No matches for: {pattern}")
+                format!("No matches for: {}", params.pattern)
             } else {
                 results.join("\n")
             },
@@ -398,27 +377,6 @@ fn grep_recursive(
     Ok(())
 }
 
-fn param_str(params: &HashMap<String, serde_json::Value>, key: &str) -> Result<String, ToolError> {
-    params
-        .get(key)
-        .and_then(|v| v.as_str())
-        .map(str::to_owned)
-        .ok_or_else(|| {
-            ToolError::Execution(std::io::Error::new(
-                std::io::ErrorKind::InvalidInput,
-                format!("missing required parameter: {key}"),
-            ))
-        })
-}
-
-fn param_usize(params: &HashMap<String, serde_json::Value>, key: &str) -> Option<usize> {
-    #[allow(clippy::cast_possible_truncation)]
-    params
-        .get(key)
-        .and_then(serde_json::Value::as_u64)
-        .map(|n| n as usize)
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -662,4 +620,13 @@ mod tests {
         assert!(props.contains_key("offset"));
         assert!(props.contains_key("limit"));
     }
+
+    #[test]
+    fn missing_required_path_returns_invalid_params() {
+        let dir = temp_dir();
+        let exec = FileExecutor::new(vec![dir.path().to_path_buf()]);
+        let params = HashMap::new();
+        let result = exec.execute_file_tool("read", &params);
+        assert!(matches!(result, Err(ToolError::InvalidParams { .. })));
+    }
 }
diff --git a/crates/zeph-tools/src/scrape.rs b/crates/zeph-tools/src/scrape.rs
index ab85238..d0a3e2e 100644
--- a/crates/zeph-tools/src/scrape.rs
+++ b/crates/zeph-tools/src/scrape.rs
@@ -5,7 +5,7 @@ use serde::Deserialize;
 use url::Url;
 
 use crate::config::ScrapeConfig;
-use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
+use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
 
 #[derive(Debug, Deserialize, JsonSchema)]
 struct ScrapeInstruction {
@@ -116,18 +116,7 @@ impl ToolExecutor for WebScrapeExecutor {
             return Ok(None);
         }
 
-        let instruction: ScrapeInstruction = serde_json::from_value(serde_json::Value::Object(
-            call.params
-                .iter()
-                .map(|(k, v)| (k.clone(), v.clone()))
-                .collect(),
-        ))
-        .map_err(|e| {
-            ToolError::Execution(std::io::Error::new(
-                std::io::ErrorKind::InvalidData,
-                e.to_string(),
-            ))
-        })?;
+        let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
 
         let result = self.scrape_instruction(&instruction).await?;
 
diff --git a/crates/zeph-tools/src/shell.rs b/crates/zeph-tools/src/shell.rs
index 97cefa9..b5588db 100644
--- a/crates/zeph-tools/src/shell.rs
+++ b/crates/zeph-tools/src/shell.rs
@@ -5,6 +5,7 @@ use tokio::process::Command;
 use tokio_util::sync::CancellationToken;
 
 use schemars::JsonSchema;
+use serde::Deserialize;
 
 use crate::audit::{AuditEntry, AuditLogger, AuditResult};
 use crate::config::ShellConfig;
@@ -21,9 +22,7 @@ const DEFAULT_BLOCKED: &[&str] = &[
 
 const NETWORK_COMMANDS: &[&str] = &["curl", "wget", "nc ", "ncat", "netcat"];
 
-// Schema-only: fields are read by schemars derive, not by Rust code directly.
-#[derive(JsonSchema)]
-#[allow(dead_code)]
+#[derive(Deserialize, JsonSchema)]
 pub(crate) struct BashParams {
     /// The bash command to execute
     command: String,
@@ -376,14 +375,11 @@ impl ToolExecutor for ShellExecutor {
         if call.tool_id != "bash" {
             return Ok(None);
         }
-        let command = call
-            .params
-            .get("command")
-            .and_then(|v| v.as_str())
-            .unwrap_or_default();
-        if command.is_empty() {
+        let params: BashParams = crate::executor::deserialize_params(&call.params)?;
+        if params.command.is_empty() {
             return Ok(None);
         }
+        let command = &params.command;
         // Wrap as a fenced block so execute_inner can extract and run it
         let synthetic = format!("```bash\n{command}\n```");
         self.execute_inner(&synthetic, false).await
@@ -1358,4 +1354,42 @@ mod tests {
         let result = executor.execute(response).await;
         assert!(matches!(result, Err(ToolError::Cancelled)));
     }
+
+    #[tokio::test]
+    #[cfg(not(target_os = "windows"))]
+    async fn execute_tool_call_valid_command() {
+        let executor = ShellExecutor::new(&default_config());
+        let call = ToolCall {
+            tool_id: "bash".to_owned(),
+            params: [("command".to_owned(), serde_json::json!("echo hi"))]
+                .into_iter()
+                .collect(),
+        };
+        let result = executor.execute_tool_call(&call).await.unwrap().unwrap();
+        assert!(result.summary.contains("hi"));
+    }
+
+    #[tokio::test]
+    async fn execute_tool_call_missing_command_returns_invalid_params() {
+        let executor = ShellExecutor::new(&default_config());
+        let call = ToolCall {
+            tool_id: "bash".to_owned(),
+            params: std::collections::HashMap::new(),
+        };
+        let result = executor.execute_tool_call(&call).await;
+        assert!(matches!(result, Err(ToolError::InvalidParams { .. })));
+    }
+
+    #[tokio::test]
+    async fn execute_tool_call_empty_command_returns_none() {
+        let executor = ShellExecutor::new(&default_config());
+        let call = ToolCall {
+            tool_id: "bash".to_owned(),
+            params: [("command".to_owned(), serde_json::json!(""))]
+                .into_iter()
+                .collect(),
+        };
+        let result = executor.execute_tool_call(&call).await.unwrap();
+        assert!(result.is_none());
+    }
 }