diff --git a/CHANGELOG.md b/CHANGELOG.md index 6159615..8dc76ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Blocked command matching extracts basename from absolute paths (`/usr/bin/sudo` now correctly blocked) - Transparent wrapper commands (`env`, `command`, `exec`, `nice`, `nohup`, `time`, `xargs`) are skipped to detect the actual command - Default confirm patterns now include `$(` and backtick subshell expressions +- Enable SQLite WAL mode with SYNCHRONOUS=NORMAL for 2-5x write throughput (#639) +- Replace O(n*iterations) token scan with cached_prompt_tokens in budget checks (#640) +- Defer maybe_redact to stream completion boundary instead of per-chunk (#641) +- Replace format_tool_output string allocation with Write-into-buffer (#642) +- Change ToolCall.params from HashMap to serde_json::Map, eliminating clone (#643) +- Pre-join static system prompt sections into LazyLock (#644) +- Replace doom-loop string history with content hash comparison (#645) +- Return &'static str from detect_image_mime with case-insensitive matching (#646) +- Replace block_on in history persist with fire-and-forget async spawn (#647) ### Fixed - False positive: "sudoku" no longer matched by "sudo" blocked pattern (word-boundary matching) diff --git a/crates/zeph-core/src/agent/mod.rs b/crates/zeph-core/src/agent/mod.rs index 7576833..732544e 100644 --- a/crates/zeph-core/src/agent/mod.rs +++ b/crates/zeph-core/src/agent/mod.rs @@ -51,22 +51,34 @@ const MAX_AUDIO_BYTES: usize = 25 * 1024 * 1024; const MAX_IMAGE_BYTES: usize = 20 * 1024 * 1024; fn format_tool_output(tool_name: &str, body: &str) -> String { - format!("[tool output: {tool_name}]\n```\n{body}{TOOL_OUTPUT_SUFFIX}") + use std::fmt::Write; + let capacity = "[tool output: ".len() + + tool_name.len() + + "]\n```\n".len() + + body.len() + + TOOL_OUTPUT_SUFFIX.len(); + let mut buf = String::with_capacity(capacity); + let _ = 
write!( + buf, + "[tool output: {tool_name}]\n```\n{body}{TOOL_OUTPUT_SUFFIX}" + ); + buf } -fn detect_image_mime(filename: Option<&str>) -> String { +fn detect_image_mime(filename: Option<&str>) -> &'static str { let ext = filename .and_then(|f| std::path::Path::new(f).extension()) .and_then(|e| e.to_str()) - .unwrap_or("") - .to_lowercase(); - match ext.as_str() { - "jpg" | "jpeg" => "image/jpeg", - "gif" => "image/gif", - "webp" => "image/webp", - _ => "image/png", - } - .to_owned() + .unwrap_or(""); + if ext.eq_ignore_ascii_case("jpg") || ext.eq_ignore_ascii_case("jpeg") { + "image/jpeg" + } else if ext.eq_ignore_ascii_case("gif") { + "image/gif" + } else if ext.eq_ignore_ascii_case("webp") { + "image/webp" + } else { + "image/png" + } } struct QueuedMessage { @@ -153,7 +165,7 @@ pub struct Agent { message_queue: VecDeque, summary_provider: Option, warmup_ready: Option>, - doom_loop_history: Vec, + doom_loop_history: Vec, cost_tracker: Option, cached_prompt_tokens: u64, stt: Option>, @@ -787,7 +799,7 @@ impl Agent { ); continue; } - let mime_type = detect_image_mime(attachment.filename.as_deref()); + let mime_type = detect_image_mime(attachment.filename.as_deref()).to_string(); image_parts.push(MessagePart::Image { data: attachment.data, mime_type, @@ -918,7 +930,7 @@ impl Agent { .await?; return Ok(()); } - let mime_type = detect_image_mime(Some(path)); + let mime_type = detect_image_mime(Some(path)).to_string(); extra_parts.push(MessagePart::Image { data, mime_type }); self.channel .send(&format!("Image loaded: {path}. 
Send your message.")) @@ -2190,23 +2202,60 @@ pub(super) mod agent_tests { #[test] fn doom_loop_detection_triggers_on_identical_outputs() { - let s = "same output".to_owned(); - let history = vec![s.clone(), s.clone(), s]; + // doom_loop_history stores u64 hashes — identical content produces equal hashes + let h = 42u64; + let history: Vec = vec![h, h, h]; let recent = &history[history.len() - DOOM_LOOP_WINDOW..]; assert!(recent.windows(2).all(|w| w[0] == w[1])); } #[test] fn doom_loop_detection_no_trigger_on_different_outputs() { - let history = vec![ - "output a".to_owned(), - "output b".to_owned(), - "output c".to_owned(), - ]; + let history: Vec = vec![1, 2, 3]; let recent = &history[history.len() - DOOM_LOOP_WINDOW..]; assert!(!recent.windows(2).all(|w| w[0] == w[1])); } + #[test] + fn format_tool_output_structure() { + let out = format_tool_output("bash", "hello world"); + assert!(out.starts_with("[tool output: bash]\n```\n")); + assert!(out.ends_with(TOOL_OUTPUT_SUFFIX)); + assert!(out.contains("hello world")); + } + + #[test] + fn format_tool_output_empty_body() { + let out = format_tool_output("grep", ""); + assert_eq!(out, "[tool output: grep]\n```\n\n```"); + } + + #[test] + fn detect_image_mime_standard() { + assert_eq!(detect_image_mime(Some("photo.jpg")), "image/jpeg"); + assert_eq!(detect_image_mime(Some("photo.jpeg")), "image/jpeg"); + assert_eq!(detect_image_mime(Some("anim.gif")), "image/gif"); + assert_eq!(detect_image_mime(Some("img.webp")), "image/webp"); + assert_eq!(detect_image_mime(Some("img.png")), "image/png"); + assert_eq!(detect_image_mime(None), "image/png"); + } + + #[test] + fn detect_image_mime_uppercase() { + assert_eq!(detect_image_mime(Some("photo.JPG")), "image/jpeg"); + assert_eq!(detect_image_mime(Some("photo.JPEG")), "image/jpeg"); + assert_eq!(detect_image_mime(Some("anim.GIF")), "image/gif"); + assert_eq!(detect_image_mime(Some("img.WEBP")), "image/webp"); + } + + #[test] + fn detect_image_mime_mixed_case() { + 
assert_eq!(detect_image_mime(Some("photo.Jpg")), "image/jpeg"); + assert_eq!(detect_image_mime(Some("photo.JpEg")), "image/jpeg"); + assert_eq!(detect_image_mime(Some("anim.Gif")), "image/gif"); + assert_eq!(detect_image_mime(Some("img.WebP")), "image/webp"); + } + #[tokio::test] async fn cancel_signal_propagates_to_fresh_token() { use tokio_util::sync::CancellationToken; diff --git a/crates/zeph-core/src/agent/streaming.rs b/crates/zeph-core/src/agent/streaming.rs index 50b13c7..d05d60e 100644 --- a/crates/zeph-core/src/agent/streaming.rs +++ b/crates/zeph-core/src/agent/streaming.rs @@ -2,15 +2,23 @@ use tokio_stream::StreamExt; use zeph_llm::provider::{ChatResponse, LlmProvider, Message, MessagePart, Role, ToolDefinition}; use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput}; +use super::{Agent, DOOM_LOOP_WINDOW, TOOL_LOOP_KEEP_RECENT, format_tool_output}; use crate::channel::Channel; use crate::redact::redact_secrets; -use zeph_memory::semantic::estimate_tokens; - -use super::{Agent, DOOM_LOOP_WINDOW, TOOL_LOOP_KEEP_RECENT, format_tool_output}; use tracing::Instrument; /// Strip volatile IDs from message content so doom-loop comparison is stable. /// Normalizes `[tool_result: ]` and `[tool_use: ()]` by removing unique IDs. +// DefaultHasher output is not stable across Rust versions — do not persist or serialize +// these hashes. They are used only for within-session equality comparison. 
+fn doom_loop_hash(content: &str) -> u64 { + use std::hash::{DefaultHasher, Hash, Hasher}; + let normalized = normalize_for_doom_loop(content); + let mut hasher = DefaultHasher::new(); + normalized.hash(&mut hasher); + hasher.finish() +} + fn normalize_for_doom_loop(content: &str) -> String { let mut out = String::with_capacity(content.len()); let mut rest = content; @@ -86,11 +94,7 @@ impl Agent { // Context budget check at 80% threshold if let Some(ref budget) = self.context_state.budget { - let used: usize = self - .messages - .iter() - .map(|m| estimate_tokens(&m.content)) - .sum(); + let used = usize::try_from(self.cached_prompt_tokens).unwrap_or(usize::MAX); let threshold = budget.max_tokens() * 4 / 5; if used >= threshold { tracing::warn!( @@ -147,10 +151,10 @@ impl Agent { // Prune tool output bodies from older iterations to reduce context growth self.prune_stale_tool_outputs(TOOL_LOOP_KEEP_RECENT); - // Doom-loop detection: compare last N outputs by string equality + // Doom-loop detection: compare last N outputs by content hash if let Some(last_msg) = self.messages.last() { self.doom_loop_history - .push(normalize_for_doom_loop(&last_msg.content)); + .push(doom_loop_hash(&last_msg.content)); if self.doom_loop_history.len() >= DOOM_LOOP_WINDOW { let recent = &self.doom_loop_history[self.doom_loop_history.len() - DOOM_LOOP_WINDOW..]; @@ -218,7 +222,12 @@ impl Agent { }); self.record_cache_usage(); self.record_cost(prompt_estimate, completion_estimate_for_cost); - Ok(Some(r?)) + let raw = r?; + // Redact secrets from the full response before it is persisted to history. + // Streaming chunks were already sent to the channel without per-chunk redaction + // (acceptable trade-off: ephemeral display vs allocation per chunk). + let redacted = self.maybe_redact(&raw).into_owned(); + Ok(Some(redacted)) } else { self.channel .send("LLM request timed out. 
Please try again.") @@ -492,8 +501,7 @@ impl Agent { }; let chunk: String = chunk_result?; response.push_str(&chunk); - let display = self.maybe_redact(&chunk); - self.channel.send_chunk(&display).await?; + self.channel.send_chunk(&chunk).await?; } self.channel.flush_chunks().await?; @@ -549,11 +557,7 @@ impl Agent { self.channel.send_typing().await?; if let Some(ref budget) = self.context_state.budget { - let used: usize = self - .messages - .iter() - .map(|m| estimate_tokens(&m.content)) - .sum(); + let used = usize::try_from(self.cached_prompt_tokens).unwrap_or(usize::MAX); let threshold = budget.max_tokens() * 4 / 5; if used >= threshold { tracing::warn!( @@ -716,11 +720,11 @@ impl Agent { let calls: Vec = tool_calls .iter() .map(|tc| { - let params: std::collections::HashMap = + let params: serde_json::Map = if let serde_json::Value::Object(map) = &tc.input { - map.iter().map(|(k, v)| (k.clone(), v.clone())).collect() + map.clone() } else { - std::collections::HashMap::new() + serde_json::Map::new() }; ToolCall { tool_id: tc.name.clone(), @@ -844,7 +848,7 @@ impl Agent { ) -> Result { if let Some(last_msg) = self.messages.last() { self.doom_loop_history - .push(normalize_for_doom_loop(&last_msg.content)); + .push(doom_loop_hash(&last_msg.content)); if self.doom_loop_history.len() >= DOOM_LOOP_WINDOW { let recent = &self.doom_loop_history[self.doom_loop_history.len() - DOOM_LOOP_WINDOW..]; @@ -879,7 +883,6 @@ fn tool_def_to_definition(def: &zeph_tools::registry::ToolDef) -> ToolDefinition #[cfg(test)] mod tests { - use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::{Duration, Instant}; @@ -1044,7 +1047,7 @@ mod tests { (0..n) .map(|i| ToolCall { tool_id: format!("tool-{i}"), - params: HashMap::new(), + params: serde_json::Map::new(), }) .collect() } diff --git a/crates/zeph-core/src/context.rs b/crates/zeph-core/src/context.rs index 401665f..22b829c 100644 --- a/crates/zeph-core/src/context.rs +++ 
b/crates/zeph-core/src/context.rs @@ -1,3 +1,5 @@ +use std::sync::LazyLock; + use zeph_memory::semantic::estimate_tokens; const BASE_PROMPT_HEADER: &str = "\ @@ -44,6 +46,26 @@ the user explicitly asks about a skill by name.\n\ - Do not force-push to main/master branches.\n\ - Do not execute commands that could cause data loss without confirmation."; +static PROMPT_LEGACY: LazyLock = LazyLock::new(|| { + let mut s = String::with_capacity( + BASE_PROMPT_HEADER.len() + TOOL_USE_LEGACY.len() + BASE_PROMPT_TAIL.len(), + ); + s.push_str(BASE_PROMPT_HEADER); + s.push_str(TOOL_USE_LEGACY); + s.push_str(BASE_PROMPT_TAIL); + s +}); + +static PROMPT_NATIVE: LazyLock = LazyLock::new(|| { + let mut s = String::with_capacity( + BASE_PROMPT_HEADER.len() + TOOL_USE_NATIVE.len() + BASE_PROMPT_TAIL.len(), + ); + s.push_str(BASE_PROMPT_HEADER); + s.push_str(TOOL_USE_NATIVE); + s.push_str(BASE_PROMPT_TAIL); + s +}); + #[must_use] pub fn build_system_prompt( skills_prompt: &str, @@ -51,15 +73,20 @@ pub fn build_system_prompt( tool_catalog: Option<&str>, native_tools: bool, ) -> String { - let mut prompt = BASE_PROMPT_HEADER.to_string(); - - if native_tools { - prompt.push_str(TOOL_USE_NATIVE); + let base = if native_tools { + &*PROMPT_NATIVE } else { - prompt.push_str(TOOL_USE_LEGACY); - } - - prompt.push_str(BASE_PROMPT_TAIL); + &*PROMPT_LEGACY + }; + let dynamic_len = env.map_or(0, |e| e.format().len() + 2) + + tool_catalog.map_or(0, |c| if c.is_empty() { 0 } else { c.len() + 2 }) + + if skills_prompt.is_empty() { + 0 + } else { + skills_prompt.len() + 2 + }; + let mut prompt = String::with_capacity(base.len() + dynamic_len); + prompt.push_str(base); if let Some(env) = env { prompt.push_str("\n\n"); diff --git a/crates/zeph-memory/src/sqlite/mod.rs b/crates/zeph-memory/src/sqlite/mod.rs index e7402fa..44f0961 100644 --- a/crates/zeph-memory/src/sqlite/mod.rs +++ b/crates/zeph-memory/src/sqlite/mod.rs @@ -37,7 +37,9 @@ impl SqliteStore { let opts = 
SqliteConnectOptions::from_str(&url)? .create_if_missing(true) - .foreign_keys(true); + .foreign_keys(true) + .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) + .synchronous(sqlx::sqlite::SqliteSynchronous::Normal); let pool = SqlitePoolOptions::new() .max_connections(5) @@ -65,3 +67,24 @@ impl SqliteStore { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[tokio::test] + async fn wal_journal_mode_enabled_on_file_db() { + let file = NamedTempFile::new().expect("tempfile"); + let path = file.path().to_str().expect("valid path"); + + let store = SqliteStore::new(path).await.expect("SqliteStore::new"); + + let mode: String = sqlx::query_scalar("PRAGMA journal_mode") + .fetch_one(store.pool()) + .await + .expect("PRAGMA query"); + + assert_eq!(mode, "wal", "expected WAL journal mode, got: {mode}"); + } +} diff --git a/crates/zeph-tools/src/composite.rs b/crates/zeph-tools/src/composite.rs index 217e862..8ae5e2e 100644 --- a/crates/zeph-tools/src/composite.rs +++ b/crates/zeph-tools/src/composite.rs @@ -210,7 +210,7 @@ mod tests { let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor); let call = ToolCall { tool_id: "read".to_owned(), - params: std::collections::HashMap::new(), + params: serde_json::Map::new(), }; let result = composite.execute_tool_call(&call).await.unwrap().unwrap(); assert_eq!(result.summary, "file_handler"); @@ -221,7 +221,7 @@ mod tests { let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor); let call = ToolCall { tool_id: "bash".to_owned(), - params: std::collections::HashMap::new(), + params: serde_json::Map::new(), }; let result = composite.execute_tool_call(&call).await.unwrap().unwrap(); assert_eq!(result.summary, "shell_handler"); @@ -232,7 +232,7 @@ mod tests { let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor); let call = ToolCall { tool_id: "unknown".to_owned(), - params: std::collections::HashMap::new(), + params: 
serde_json::Map::new(), }; let result = composite.execute_tool_call(&call).await.unwrap(); assert!(result.is_none()); diff --git a/crates/zeph-tools/src/executor.rs b/crates/zeph-tools/src/executor.rs index 00095e3..47c46e6 100644 --- a/crates/zeph-tools/src/executor.rs +++ b/crates/zeph-tools/src/executor.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::fmt; /// Data for rendering file diffs in the TUI. @@ -13,7 +12,7 @@ pub struct DiffData { #[derive(Debug, Clone)] pub struct ToolCall { pub tool_id: String, - pub params: HashMap, + pub params: serde_json::Map, } /// Cumulative filter statistics for a single tool execution. @@ -153,16 +152,15 @@ pub enum ToolError { Execution(#[from] std::io::Error), } -/// Deserialize tool call params from a `HashMap` into a typed struct. +/// Deserialize tool call params from a `serde_json::Map` into a typed struct. /// /// # Errors /// /// Returns `ToolError::InvalidParams` when deserialization fails. -pub fn deserialize_params( - params: &HashMap, +pub fn deserialize_params( + params: &serde_json::Map, ) -> Result { - let obj = - serde_json::Value::Object(params.iter().map(|(k, v)| (k.clone(), v.clone())).collect()); + let obj = serde_json::Value::Object(params.clone()); serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams { message: e.to_string(), }) @@ -292,7 +290,7 @@ mod tests { name: String, count: u32, } - let mut map = HashMap::new(); + let mut map = serde_json::Map::new(); map.insert("name".to_owned(), serde_json::json!("test")); map.insert("count".to_owned(), serde_json::json!(42)); let p: P = deserialize_params(&map).unwrap(); @@ -312,8 +310,8 @@ mod tests { #[allow(dead_code)] name: String, } - let map: HashMap = HashMap::new(); - let err = deserialize_params::(&map).unwrap_err(); + let map = serde_json::Map::new(); + let err = deserialize_params::<P>(&map).unwrap_err(); assert!(matches!(err, ToolError::InvalidParams { .. })); } @@ -324,9 +322,9 @@ mod tests { #[allow(dead_code)] count: u32, } - let mut map = HashMap::new(); + let mut map = serde_json::Map::new(); map.insert("count".to_owned(), serde_json::json!("not a number")); - let err = deserialize_params::(&map).unwrap_err(); + let err = deserialize_params::<P>(&map).unwrap_err(); assert!(matches!(err, ToolError::InvalidParams { .. })); } @@ -336,7 +334,7 @@ mod tests { struct P { name: Option, } - let map: HashMap = HashMap::new(); + let map = serde_json::Map::new(); let p: P = deserialize_params(&map).unwrap(); assert_eq!(p, P { name: None }); } @@ -347,7 +345,7 @@ mod tests { struct P { name: String, } - let mut map = HashMap::new(); + let mut map = serde_json::Map::new(); map.insert("name".to_owned(), serde_json::json!("test")); map.insert("extra".to_owned(), serde_json::json!(true)); let p: P = deserialize_params(&map).unwrap(); @@ -408,7 +406,7 @@ mod tests { let exec = DefaultExecutor; let call = ToolCall { tool_id: "anything".to_owned(), - params: std::collections::HashMap::new(), + params: serde_json::Map::new(), }; let result = exec.execute_tool_call(&call).await.unwrap(); assert!(result.is_none()); diff --git a/crates/zeph-tools/src/file.rs b/crates/zeph-tools/src/file.rs index cda61d0..2b3bee0 100644 --- a/crates/zeph-tools/src/file.rs +++ b/crates/zeph-tools/src/file.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::path::{Path, PathBuf}; use schemars::JsonSchema; @@ -100,7 +99,7 @@ impl FileExecutor { pub fn execute_file_tool( &self, tool_id: &str, - params: &HashMap, + params: &serde_json::Map, ) -> Result, ToolError> { match tool_id { "read" => { @@ -386,7 +385,9 @@ mod tests { tempfile::tempdir().unwrap() } - fn make_params(pairs: &[(&str, serde_json::Value)]) -> HashMap { + fn make_params( + pairs: &[(&str, serde_json::Value)], + ) -> serde_json::Map { pairs .iter() .map(|(k, v)| ((*k).to_owned(), v.clone())) @@ -486,7 +487,7 @@ mod tests { #[test] fn unknown_tool_returns_none() { let exec = FileExecutor::new(vec![]); - let params = HashMap::new(); + let params = serde_json::Map::new(); let result = exec.execute_file_tool("unknown", &params).unwrap(); assert!(result.is_none()); } @@ -625,7 +626,7 @@ mod tests { fn missing_required_path_returns_invalid_params() { let dir = temp_dir(); let exec = 
FileExecutor::new(vec![dir.path().to_path_buf()]); - let params = HashMap::new(); + let params = serde_json::Map::new(); let result = exec.execute_file_tool("read", &params); assert!(matches!(result, Err(ToolError::InvalidParams { .. }))); } diff --git a/crates/zeph-tools/src/shell.rs b/crates/zeph-tools/src/shell.rs index 642bdcc..9de1272 100644 --- a/crates/zeph-tools/src/shell.rs +++ b/crates/zeph-tools/src/shell.rs @@ -1730,7 +1730,7 @@ mod tests { let executor = ShellExecutor::new(&default_config()); let call = ToolCall { tool_id: "bash".to_owned(), - params: std::collections::HashMap::new(), + params: serde_json::Map::new(), }; let result = executor.execute_tool_call(&call).await; assert!(matches!(result, Err(ToolError::InvalidParams { .. }))); diff --git a/crates/zeph-tools/src/trust_gate.rs b/crates/zeph-tools/src/trust_gate.rs index e9d1f7b..b59c47c 100644 --- a/crates/zeph-tools/src/trust_gate.rs +++ b/crates/zeph-tools/src/trust_gate.rs @@ -102,7 +102,6 @@ impl ToolExecutor for TrustGateExecutor { #[cfg(test)] mod tests { use super::*; - use std::collections::HashMap; #[derive(Debug)] struct MockExecutor; @@ -128,12 +127,12 @@ mod tests { fn make_call(tool_id: &str) -> ToolCall { ToolCall { tool_id: tool_id.into(), - params: HashMap::new(), + params: serde_json::Map::new(), } } fn make_call_with_cmd(tool_id: &str, cmd: &str) -> ToolCall { - let mut params = HashMap::new(); + let mut params = serde_json::Map::new(); params.insert("command".into(), serde_json::Value::String(cmd.into())); ToolCall { tool_id: tool_id.into(), diff --git a/src/main.rs b/src/main.rs index 74e8fda..e374c0f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -257,10 +257,12 @@ async fn main() -> anyhow::Result<()> { let persist: Box = Box::new(move |text: &str| { let store = store.clone(); let text = text.to_owned(); - if let Ok(handle) = tokio::runtime::Handle::try_current() - && let Err(e) = handle.block_on(store.save_input_entry(&text)) - { - tracing::warn!("failed to persist input 
history entry: {e}"); + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.spawn(async move { + if let Err(e) = store.save_input_entry(&text).await { + tracing::warn!("failed to persist input history entry: {e}"); + } + }); } }); Some((entries, persist))