diff --git a/CHANGELOG.md b/CHANGELOG.md index 155802b..394e9f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Path separator rejection in MCP command validation to prevent symlink-based bypasses ### Changed +- `MessagePart::Image` variant now holds `Box<ImageData>` instead of inline fields, improving semantic grouping of image data +- `Agent<C, T>` simplified to `Agent<C>` — ToolExecutor generic replaced with `Box<dyn ErasedToolExecutor>`, reducing monomorphization - Shell command detection rewritten from substring matching to tokenizer-based pipeline with escape normalization, eliminating bypass vectors via backslash insertion, hex/octal escapes, quote splitting, and pipe chains - Shell sandbox path validation now uses `std::path::absolute()` as fallback when `canonicalize()` fails on non-existent paths - Blocked command matching extracts basename from absolute paths (`/usr/bin/sudo` now correctly blocked) diff --git a/crates/zeph-core/src/agent/builder.rs b/crates/zeph-core/src/agent/builder.rs index 2299a78..fbd4bc0 100644 --- a/crates/zeph-core/src/agent/builder.rs +++ b/crates/zeph-core/src/agent/builder.rs @@ -5,6 +5,7 @@ use tokio::sync::{Notify, mpsc, watch}; use zeph_llm::any::AnyProvider; use zeph_llm::provider::LlmProvider; +use super::Agent; use crate::channel::Channel; use crate::config::{LearningConfig, SecurityConfig, TimeoutConfig}; use crate::config_watcher::ConfigEvent; @@ -13,11 +14,8 @@ use crate::cost::CostTracker; use crate::metrics::MetricsSnapshot; use zeph_memory::semantic::SemanticMemory; use zeph_skills::watcher::SkillEvent; -use zeph_tools::executor::ToolExecutor; - -use super::Agent; -impl Agent { +impl Agent { #[must_use] pub fn with_stt(mut self, stt: Box) -> Self { self.stt = Some(stt); diff --git a/crates/zeph-core/src/agent/commands.rs b/crates/zeph-core/src/agent/commands.rs deleted file mode 100644 index 6815565..0000000 --- a/crates/zeph-core/src/agent/commands.rs +++ 
/dev/null @@ -1,198 +0,0 @@ -use zeph_llm::provider::LlmProvider; -use zeph_skills::loader::Skill; -use zeph_skills::matcher::{SkillMatcher, SkillMatcherBackend}; -use zeph_skills::prompt::format_skills_prompt; -use zeph_skills::registry::SkillRegistry; - -use crate::channel::Channel; -use crate::config::Config; -use crate::context::{ContextBudget, build_system_prompt}; -use zeph_tools::executor::ToolExecutor; - -use super::{Agent, error}; - -impl Agent { - pub(super) async fn handle_skills_command(&mut self) -> Result<(), error::AgentError> { - use std::fmt::Write; - - let mut output = String::from("Available skills:\n\n"); - - for meta in self.skill_state.registry.all_meta() { - let trust_info = if let Some(memory) = &self.memory_state.memory { - memory - .sqlite() - .load_skill_trust(&meta.name) - .await - .ok() - .flatten() - .map_or_else(String::new, |r| format!(" [{}]", r.trust_level)) - } else { - String::new() - }; - let _ = writeln!(output, "- {} — {}{trust_info}", meta.name, meta.description); - } - - if let Some(memory) = &self.memory_state.memory { - match memory.sqlite().load_skill_usage().await { - Ok(usage) if !usage.is_empty() => { - output.push_str("\nUsage statistics:\n\n"); - for row in &usage { - let _ = writeln!( - output, - "- {}: {} invocations (last: {})", - row.skill_name, row.invocation_count, row.last_used_at, - ); - } - } - Ok(_) => {} - Err(e) => tracing::warn!("failed to load skill usage: {e:#}"), - } - } - - self.channel.send(&output).await?; - Ok(()) - } - - pub(super) async fn handle_feedback(&mut self, input: &str) -> Result<(), error::AgentError> { - let Some((name, rest)) = input.split_once(' ') else { - self.channel - .send("Usage: /feedback ") - .await?; - return Ok(()); - }; - let (skill_name, feedback) = (name.trim(), rest.trim().trim_matches('"')); - - if feedback.is_empty() { - self.channel - .send("Usage: /feedback ") - .await?; - return Ok(()); - } - - let Some(memory) = &self.memory_state.memory else { - 
self.channel.send("Memory not available.").await?; - return Ok(()); - }; - - memory - .sqlite() - .record_skill_outcome( - skill_name, - None, - self.memory_state.conversation_id, - "user_rejection", - Some(feedback), - ) - .await?; - - if self.is_learning_enabled() { - self.generate_improved_skill(skill_name, feedback, "", Some(feedback)) - .await - .ok(); - } - - self.channel - .send(&format!("Feedback recorded for \"{skill_name}\".")) - .await?; - Ok(()) - } - - pub(super) async fn reload_skills(&mut self) { - let new_registry = SkillRegistry::load(&self.skill_state.skill_paths); - if new_registry.fingerprint() == self.skill_state.registry.fingerprint() { - return; - } - self.skill_state.registry = new_registry; - - let all_meta = self.skill_state.registry.all_meta(); - let provider = self.provider.clone(); - let embed_fn = |text: &str| -> zeph_skills::matcher::EmbedFuture { - let owned = text.to_owned(); - let p = provider.clone(); - Box::pin(async move { p.embed(&owned).await }) - }; - - let needs_inmemory_rebuild = !self - .skill_state - .matcher - .as_ref() - .is_some_and(SkillMatcherBackend::is_qdrant); - - if needs_inmemory_rebuild { - self.skill_state.matcher = SkillMatcher::new(&all_meta, embed_fn) - .await - .map(SkillMatcherBackend::InMemory); - } else if let Some(ref mut backend) = self.skill_state.matcher - && let Err(e) = backend - .sync(&all_meta, &self.skill_state.embedding_model, embed_fn) - .await - { - tracing::warn!("failed to sync skill embeddings: {e:#}"); - } - - let all_skills: Vec = self - .skill_state - .registry - .all_meta() - .iter() - .filter_map(|m| self.skill_state.registry.get_skill(&m.name).ok()) - .collect(); - let trust_map = self.build_skill_trust_map().await; - let skills_prompt = format_skills_prompt(&all_skills, std::env::consts::OS, &trust_map); - self.skill_state - .last_skills_prompt - .clone_from(&skills_prompt); - let system_prompt = build_system_prompt(&skills_prompt, None, None, false); - if let Some(msg) = 
self.messages.first_mut() { - msg.content = system_prompt; - } - - tracing::info!( - "reloaded {} skill(s)", - self.skill_state.registry.all_meta().len() - ); - } - - pub(super) fn reload_config(&mut self) { - let Some(ref path) = self.config_path else { - return; - }; - let config = match Config::load(path) { - Ok(c) => c, - Err(e) => { - tracing::warn!("config reload failed: {e:#}"); - return; - } - }; - - self.runtime.security = config.security; - self.runtime.timeouts = config.timeouts; - self.memory_state.history_limit = config.memory.history_limit; - self.memory_state.recall_limit = config.memory.semantic.recall_limit; - self.memory_state.summarization_threshold = config.memory.summarization_threshold; - self.skill_state.max_active_skills = config.skills.max_active_skills; - self.skill_state.disambiguation_threshold = config.skills.disambiguation_threshold; - - if config.memory.context_budget_tokens > 0 { - self.context_state.budget = Some(ContextBudget::new( - config.memory.context_budget_tokens, - 0.20, - )); - } else { - self.context_state.budget = None; - } - self.context_state.compaction_threshold = config.memory.compaction_threshold; - self.context_state.compaction_preserve_tail = config.memory.compaction_preserve_tail; - self.context_state.prune_protect_tokens = config.memory.prune_protect_tokens; - self.memory_state.cross_session_score_threshold = - config.memory.cross_session_score_threshold; - - #[cfg(feature = "index")] - { - self.index.repo_map_ttl = - std::time::Duration::from_secs(config.index.repo_map_ttl_secs); - } - - tracing::info!("config reloaded"); - } -} diff --git a/crates/zeph-core/src/agent/context.rs b/crates/zeph-core/src/agent/context.rs index e181141..62ef088 100644 --- a/crates/zeph-core/src/agent/context.rs +++ b/crates/zeph-core/src/agent/context.rs @@ -8,11 +8,11 @@ use zeph_skills::prompt::format_skills_catalog; use super::{ Agent, CODE_CONTEXT_PREFIX, CROSS_SESSION_PREFIX, Channel, ContextBudget, EnvironmentContext, - 
LlmProvider, Message, RECALL_PREFIX, Role, SUMMARY_PREFIX, Skill, ToolExecutor, - build_system_prompt, format_skills_prompt, + LlmProvider, Message, RECALL_PREFIX, Role, SUMMARY_PREFIX, Skill, build_system_prompt, + format_skills_prompt, }; -impl Agent { +impl Agent { #[allow( clippy::cast_precision_loss, clippy::cast_possible_truncation, @@ -776,7 +776,7 @@ impl Agent { // Native tool_use: tools are passed via API, skip prompt-based instructions None } else { - let defs = self.tool_executor.tool_definitions(); + let defs = self.tool_executor.tool_definitions_erased(); if defs.is_empty() { None } else { diff --git a/crates/zeph-core/src/agent/index.rs b/crates/zeph-core/src/agent/index.rs index 5ee43fc..e913765 100644 --- a/crates/zeph-core/src/agent/index.rs +++ b/crates/zeph-core/src/agent/index.rs @@ -1,6 +1,6 @@ -use super::{Agent, Channel, ToolExecutor}; +use super::{Agent, Channel}; -impl Agent { +impl Agent { pub(super) async fn fetch_code_rag( index: &super::IndexState, query: &str, diff --git a/crates/zeph-core/src/agent/learning.rs b/crates/zeph-core/src/agent/learning.rs index 589b0b5..dea2e88 100644 --- a/crates/zeph-core/src/agent/learning.rs +++ b/crates/zeph-core/src/agent/learning.rs @@ -1,10 +1,10 @@ -use super::{Agent, Channel, LlmProvider, ToolExecutor}; +use super::{Agent, Channel, LlmProvider}; use super::{LearningConfig, Message, Role, SemanticMemory}; use std::path::PathBuf; -impl Agent { +impl Agent { pub(super) fn is_learning_enabled(&self) -> bool { self.learning_config.as_ref().is_some_and(|c| c.enabled) } diff --git a/crates/zeph-core/src/agent/mcp.rs b/crates/zeph-core/src/agent/mcp.rs index ecdfd31..b475a31 100644 --- a/crates/zeph-core/src/agent/mcp.rs +++ b/crates/zeph-core/src/agent/mcp.rs @@ -1,6 +1,6 @@ -use super::{Agent, Channel, LlmProvider, ToolExecutor}; +use super::{Agent, Channel, LlmProvider}; -impl Agent { +impl Agent { pub(super) async fn handle_mcp_command( &mut self, args: &str, diff --git 
a/crates/zeph-core/src/agent/message_queue.rs b/crates/zeph-core/src/agent/message_queue.rs index cdd4d55..e837958 100644 --- a/crates/zeph-core/src/agent/message_queue.rs +++ b/crates/zeph-core/src/agent/message_queue.rs @@ -1,7 +1,6 @@ use std::time::{Duration, Instant}; use crate::channel::Channel; -use zeph_tools::executor::ToolExecutor; use super::Agent; @@ -33,7 +32,7 @@ pub(super) fn detect_image_mime(filename: Option<&str>) -> &'static str { } } -impl Agent { +impl Agent { pub(super) fn drain_channel(&mut self) { while self.message_queue.len() < MAX_QUEUE_SIZE { let Some(msg) = self.channel.try_recv() else { @@ -87,141 +86,6 @@ impl Agent { self.message_queue.clear(); count } - - pub(super) async fn resolve_message( - &self, - msg: crate::channel::ChannelMessage, - ) -> (String, Vec) { - use crate::channel::{Attachment, AttachmentKind}; - use zeph_llm::provider::MessagePart; - - let text_base = msg.text.clone(); - - let (audio_attachments, image_attachments): (Vec, Vec) = msg - .attachments - .into_iter() - .partition(|a| a.kind == AttachmentKind::Audio); - - tracing::debug!( - audio = audio_attachments.len(), - has_stt = self.stt.is_some(), - "resolve_message attachments" - ); - - let text = if !audio_attachments.is_empty() - && let Some(stt) = self.stt.as_ref() - { - let mut transcribed_parts = Vec::new(); - for attachment in &audio_attachments { - if attachment.data.len() > MAX_AUDIO_BYTES { - tracing::warn!( - size = attachment.data.len(), - max = MAX_AUDIO_BYTES, - "audio attachment exceeds size limit, skipping" - ); - continue; - } - match stt - .transcribe(&attachment.data, attachment.filename.as_deref()) - .await - { - Ok(result) => { - tracing::info!( - len = result.text.len(), - language = ?result.language, - "audio transcribed" - ); - transcribed_parts.push(result.text); - } - Err(e) => { - tracing::error!(error = %e, "audio transcription failed"); - } - } - } - if transcribed_parts.is_empty() { - text_base - } else { - let transcribed = 
transcribed_parts.join("\n"); - if text_base.is_empty() { - transcribed - } else { - format!("[transcribed audio]\n{transcribed}\n\n{text_base}") - } - } - } else { - if !audio_attachments.is_empty() { - tracing::warn!( - count = audio_attachments.len(), - "audio attachments received but no STT provider configured, dropping" - ); - } - text_base - }; - - let mut image_parts = Vec::new(); - for attachment in image_attachments { - if attachment.data.len() > MAX_IMAGE_BYTES { - tracing::warn!( - size = attachment.data.len(), - max = MAX_IMAGE_BYTES, - "image attachment exceeds size limit, skipping" - ); - continue; - } - let mime_type = detect_image_mime(attachment.filename.as_deref()).to_string(); - image_parts.push(MessagePart::Image { - data: attachment.data, - mime_type, - }); - } - - (text, image_parts) - } - - pub(super) async fn handle_image_command( - &mut self, - path: &str, - extra_parts: &mut Vec, - ) -> Result<(), super::error::AgentError> { - use std::path::Component; - use zeph_llm::provider::MessagePart; - - // Reject paths that traverse outside the current directory. - let has_parent_dir = std::path::Path::new(path) - .components() - .any(|c| c == Component::ParentDir); - if has_parent_dir { - self.channel - .send("Invalid image path: path traversal not allowed") - .await?; - return Ok(()); - } - - let data = match std::fs::read(path) { - Ok(d) => d, - Err(e) => { - self.channel - .send(&format!("Cannot read image {path}: {e}")) - .await?; - return Ok(()); - } - }; - if data.len() > MAX_IMAGE_BYTES { - self.channel - .send(&format!( - "Image {path} exceeds size limit ({} MB), skipping", - MAX_IMAGE_BYTES / 1024 / 1024 - )) - .await?; - return Ok(()); - } - let mime_type = detect_image_mime(Some(path)).to_string(); - extra_parts.push(MessagePart::Image { data, mime_type }); - self.channel - .send(&format!("Image loaded: {path}. 
Send your message.")) - .await?; - Ok(()) - } } #[cfg(test)] @@ -378,319 +242,4 @@ mod tests { assert_eq!(agent.message_queue.pop_front().unwrap().text, "msg1"); assert_eq!(agent.message_queue.pop_front().unwrap().text, "msg2"); } - - #[test] - fn detect_image_mime_standard() { - assert_eq!(detect_image_mime(Some("photo.jpg")), "image/jpeg"); - assert_eq!(detect_image_mime(Some("photo.jpeg")), "image/jpeg"); - assert_eq!(detect_image_mime(Some("anim.gif")), "image/gif"); - assert_eq!(detect_image_mime(Some("img.webp")), "image/webp"); - assert_eq!(detect_image_mime(Some("img.png")), "image/png"); - assert_eq!(detect_image_mime(None), "image/png"); - } - - #[test] - fn detect_image_mime_uppercase() { - assert_eq!(detect_image_mime(Some("photo.JPG")), "image/jpeg"); - assert_eq!(detect_image_mime(Some("photo.JPEG")), "image/jpeg"); - assert_eq!(detect_image_mime(Some("anim.GIF")), "image/gif"); - assert_eq!(detect_image_mime(Some("img.WEBP")), "image/webp"); - } - - #[test] - fn detect_image_mime_mixed_case() { - assert_eq!(detect_image_mime(Some("photo.Jpg")), "image/jpeg"); - assert_eq!(detect_image_mime(Some("photo.JpEg")), "image/jpeg"); - assert_eq!(detect_image_mime(Some("anim.Gif")), "image/gif"); - assert_eq!(detect_image_mime(Some("img.WebP")), "image/webp"); - } - - #[test] - fn detect_image_mime_jpeg() { - assert_eq!(detect_image_mime(Some("photo.jpg")), "image/jpeg"); - assert_eq!(detect_image_mime(Some("photo.jpeg")), "image/jpeg"); - } - - #[test] - fn detect_image_mime_gif() { - assert_eq!(detect_image_mime(Some("anim.gif")), "image/gif"); - } - - #[test] - fn detect_image_mime_webp() { - assert_eq!(detect_image_mime(Some("img.webp")), "image/webp"); - } - - #[test] - fn detect_image_mime_unknown_defaults_png() { - assert_eq!(detect_image_mime(Some("file.bmp")), "image/png"); - assert_eq!(detect_image_mime(None), "image/png"); - } - - #[tokio::test] - async fn resolve_message_extracts_image_attachment() { - use crate::channel::{Attachment, 
AttachmentKind, ChannelMessage}; - let provider = mock_provider(vec![]); - let channel = MockChannel::new(vec![]); - let registry = create_test_registry(); - let executor = MockToolExecutor::no_tools(); - let agent = Agent::new(provider, channel, registry, None, 5, executor); - - let msg = ChannelMessage { - text: "look at this".into(), - attachments: vec![Attachment { - kind: AttachmentKind::Image, - data: vec![0u8; 16], - filename: Some("test.jpg".into()), - }], - }; - let (text, parts) = agent.resolve_message(msg).await; - assert_eq!(text, "look at this"); - assert_eq!(parts.len(), 1); - match &parts[0] { - zeph_llm::provider::MessagePart::Image { mime_type, data } => { - assert_eq!(mime_type, "image/jpeg"); - assert_eq!(data.len(), 16); - } - _ => panic!("expected Image part"), - } - } - - #[tokio::test] - async fn resolve_message_drops_oversized_image() { - use crate::channel::{Attachment, AttachmentKind, ChannelMessage}; - let provider = mock_provider(vec![]); - let channel = MockChannel::new(vec![]); - let registry = create_test_registry(); - let executor = MockToolExecutor::no_tools(); - let agent = Agent::new(provider, channel, registry, None, 5, executor); - - let msg = ChannelMessage { - text: "big image".into(), - attachments: vec![Attachment { - kind: AttachmentKind::Image, - data: vec![0u8; MAX_IMAGE_BYTES + 1], - filename: Some("huge.png".into()), - }], - }; - let (text, parts) = agent.resolve_message(msg).await; - assert_eq!(text, "big image"); - assert!(parts.is_empty()); - } - - #[tokio::test] - async fn handle_image_command_rejects_path_traversal() { - let provider = mock_provider(vec![]); - let channel = MockChannel::new(vec![]); - let registry = create_test_registry(); - let executor = MockToolExecutor::no_tools(); - let mut agent = Agent::new(provider, channel, registry, None, 5, executor); - - let mut parts = Vec::new(); - let result = agent - .handle_image_command("../../etc/passwd", &mut parts) - .await; - assert!(result.is_ok()); - 
assert!(parts.is_empty()); - let sent = agent.channel.sent_messages(); - assert!(sent.iter().any(|m| m.contains("traversal"))); - } - - #[tokio::test] - async fn handle_image_command_missing_file_sends_error() { - let provider = mock_provider(vec![]); - let channel = MockChannel::new(vec![]); - let registry = create_test_registry(); - let executor = MockToolExecutor::no_tools(); - let mut agent = Agent::new(provider, channel, registry, None, 5, executor); - - let mut parts = Vec::new(); - let result = agent - .handle_image_command("/nonexistent/image.png", &mut parts) - .await; - assert!(result.is_ok()); - assert!(parts.is_empty()); - let sent = agent.channel.sent_messages(); - assert!(sent.iter().any(|m| m.contains("Cannot read image"))); - } - - #[tokio::test] - async fn handle_image_command_loads_valid_file() { - use std::io::Write; - let provider = mock_provider(vec![]); - let channel = MockChannel::new(vec![]); - let registry = create_test_registry(); - let executor = MockToolExecutor::no_tools(); - let mut agent = Agent::new(provider, channel, registry, None, 5, executor); - - let mut tmp = tempfile::NamedTempFile::with_suffix(".jpg").unwrap(); - let data = vec![0xFFu8, 0xD8, 0xFF, 0xE0]; - tmp.write_all(&data).unwrap(); - let path = tmp.path().to_str().unwrap().to_owned(); - - let mut parts = Vec::new(); - let result = agent.handle_image_command(&path, &mut parts).await; - assert!(result.is_ok()); - assert_eq!(parts.len(), 1); - match &parts[0] { - zeph_llm::provider::MessagePart::Image { - data: img_data, - mime_type, - } => { - assert_eq!(img_data, &data); - assert_eq!(mime_type, "image/jpeg"); - } - _ => panic!("expected Image part"), - } - let sent = agent.channel.sent_messages(); - assert!(sent.iter().any(|m| m.contains("Image loaded"))); - } - - mod resolve_message_tests { - use super::super::super::agent_tests::{MockChannel, MockToolExecutor, mock_provider}; - use super::*; - use crate::channel::{Attachment, AttachmentKind, ChannelMessage}; - use 
std::future::Future; - use std::pin::Pin; - use zeph_llm::error::LlmError; - use zeph_llm::stt::{SpeechToText, Transcription}; - - struct MockStt { - text: Option, - } - - impl MockStt { - fn ok(text: &str) -> Self { - Self { - text: Some(text.to_string()), - } - } - - fn failing() -> Self { - Self { text: None } - } - } - - impl SpeechToText for MockStt { - fn transcribe( - &self, - _audio: &[u8], - _filename: Option<&str>, - ) -> Pin> + Send + '_>> - { - let result = match &self.text { - Some(t) => Ok(Transcription { - text: t.clone(), - language: None, - duration_secs: None, - }), - None => Err(LlmError::TranscriptionFailed("mock error".into())), - }; - Box::pin(async move { result }) - } - } - - fn make_agent(stt: Option>) -> Agent { - let provider = mock_provider(vec!["ok".into()]); - let empty: Vec = vec![]; - let registry = zeph_skills::registry::SkillRegistry::load(&empty); - let channel = MockChannel::new(vec![]); - let executor = MockToolExecutor::no_tools(); - let mut agent = Agent::new(provider, channel, registry, None, 5, executor); - agent.stt = stt; - agent - } - - fn audio_attachment(data: &[u8]) -> Attachment { - Attachment { - kind: AttachmentKind::Audio, - data: data.to_vec(), - filename: Some("test.wav".into()), - } - } - - #[tokio::test] - async fn no_audio_attachments_returns_text() { - let agent = make_agent(None); - let msg = ChannelMessage { - text: "hello".into(), - attachments: vec![], - }; - assert_eq!(agent.resolve_message(msg).await.0, "hello"); - } - - #[tokio::test] - async fn audio_without_stt_returns_original_text() { - let agent = make_agent(None); - let msg = ChannelMessage { - text: "hello".into(), - attachments: vec![audio_attachment(b"audio-data")], - }; - assert_eq!(agent.resolve_message(msg).await.0, "hello"); - } - - #[tokio::test] - async fn audio_with_stt_prepends_transcription() { - let agent = make_agent(Some(Box::new(MockStt::ok("transcribed text")))); - let msg = ChannelMessage { - text: "original".into(), - 
attachments: vec![audio_attachment(b"audio-data")], - }; - let (result, _) = agent.resolve_message(msg).await; - assert!(result.contains("[transcribed audio]")); - assert!(result.contains("transcribed text")); - assert!(result.contains("original")); - } - - #[tokio::test] - async fn audio_with_stt_no_original_text() { - let agent = make_agent(Some(Box::new(MockStt::ok("transcribed text")))); - let msg = ChannelMessage { - text: String::new(), - attachments: vec![audio_attachment(b"audio-data")], - }; - let (result, _) = agent.resolve_message(msg).await; - assert_eq!(result, "transcribed text"); - } - - #[tokio::test] - async fn all_transcriptions_fail_returns_original() { - let agent = make_agent(Some(Box::new(MockStt::failing()))); - let msg = ChannelMessage { - text: "original".into(), - attachments: vec![audio_attachment(b"audio-data")], - }; - assert_eq!(agent.resolve_message(msg).await.0, "original"); - } - - #[tokio::test] - async fn multiple_audio_attachments_joined() { - let agent = make_agent(Some(Box::new(MockStt::ok("chunk")))); - let msg = ChannelMessage { - text: String::new(), - attachments: vec![ - audio_attachment(b"a1"), - audio_attachment(b"a2"), - audio_attachment(b"a3"), - ], - }; - let (result, _) = agent.resolve_message(msg).await; - assert_eq!(result, "chunk\nchunk\nchunk"); - } - - #[tokio::test] - async fn oversized_audio_skipped() { - let agent = make_agent(Some(Box::new(MockStt::ok("should not appear")))); - let big = vec![0u8; MAX_AUDIO_BYTES + 1]; - let msg = ChannelMessage { - text: "original".into(), - attachments: vec![Attachment { - kind: AttachmentKind::Audio, - data: big, - filename: None, - }], - }; - assert_eq!(agent.resolve_message(msg).await.0, "original"); - } - } } diff --git a/crates/zeph-core/src/agent/mod.rs b/crates/zeph-core/src/agent/mod.rs index fe11c98..8014669 100644 --- a/crates/zeph-core/src/agent/mod.rs +++ b/crates/zeph-core/src/agent/mod.rs @@ -1,5 +1,4 @@ mod builder; -mod commands; mod context; pub mod error; 
#[cfg(feature = "index")] @@ -28,20 +27,21 @@ use crate::metrics::MetricsSnapshot; use std::collections::HashMap; use zeph_memory::semantic::SemanticMemory; use zeph_skills::loader::Skill; -use zeph_skills::matcher::SkillMatcherBackend; +use zeph_skills::matcher::{SkillMatcher, SkillMatcherBackend}; use zeph_skills::prompt::format_skills_prompt; use zeph_skills::registry::SkillRegistry; use zeph_skills::watcher::SkillEvent; -use zeph_tools::executor::ToolExecutor; +use zeph_tools::executor::{ErasedToolExecutor, ToolExecutor}; use crate::channel::Channel; +use crate::config::Config; use crate::config::LearningConfig; use crate::config::{SecurityConfig, TimeoutConfig}; use crate::config_watcher::ConfigEvent; use crate::context::{ContextBudget, EnvironmentContext, build_system_prompt}; use crate::cost::CostTracker; -use message_queue::QueuedMessage; +use message_queue::{MAX_AUDIO_BYTES, MAX_IMAGE_BYTES, QueuedMessage, detect_image_mime}; pub(crate) const DOOM_LOOP_WINDOW: usize = 3; const TOOL_LOOP_KEEP_RECENT: usize = 4; @@ -119,10 +119,10 @@ pub(super) struct RuntimeConfig { pub(super) permission_policy: zeph_tools::PermissionPolicy, } -pub struct Agent { +pub struct Agent { provider: AnyProvider, channel: C, - tool_executor: T, + tool_executor: Box, messages: Vec, pub(super) memory_state: MemoryState, pub(super) skill_state: SkillState, @@ -150,7 +150,7 @@ pub struct Agent { update_notify_rx: Option>, } -impl Agent { +impl Agent { #[must_use] pub fn new( provider: AnyProvider, @@ -158,7 +158,7 @@ impl Agent { registry: SkillRegistry, matcher: Option, max_active_skills: usize, - tool_executor: T, + tool_executor: impl ToolExecutor + 'static, ) -> Self { let all_skills: Vec = registry .all_meta() @@ -176,7 +176,7 @@ impl Agent { Self { provider, channel, - tool_executor, + tool_executor: Box::new(tool_executor), messages: vec![Message { role: Role::System, content: system_prompt, @@ -350,6 +350,96 @@ impl Agent { Ok(()) } + async fn resolve_message( + &self, + msg: 
crate::channel::ChannelMessage, + ) -> (String, Vec) { + use crate::channel::{Attachment, AttachmentKind}; + use zeph_llm::provider::{ImageData, MessagePart}; + + let text_base = msg.text.clone(); + + let (audio_attachments, image_attachments): (Vec, Vec) = msg + .attachments + .into_iter() + .partition(|a| a.kind == AttachmentKind::Audio); + + tracing::debug!( + audio = audio_attachments.len(), + has_stt = self.stt.is_some(), + "resolve_message attachments" + ); + + let text = if !audio_attachments.is_empty() + && let Some(stt) = self.stt.as_ref() + { + let mut transcribed_parts = Vec::new(); + for attachment in &audio_attachments { + if attachment.data.len() > MAX_AUDIO_BYTES { + tracing::warn!( + size = attachment.data.len(), + max = MAX_AUDIO_BYTES, + "audio attachment exceeds size limit, skipping" + ); + continue; + } + match stt + .transcribe(&attachment.data, attachment.filename.as_deref()) + .await + { + Ok(result) => { + tracing::info!( + len = result.text.len(), + language = ?result.language, + "audio transcribed" + ); + transcribed_parts.push(result.text); + } + Err(e) => { + tracing::error!(error = %e, "audio transcription failed"); + } + } + } + if transcribed_parts.is_empty() { + text_base + } else { + let transcribed = transcribed_parts.join("\n"); + if text_base.is_empty() { + transcribed + } else { + format!("[transcribed audio]\n{transcribed}\n\n{text_base}") + } + } + } else { + if !audio_attachments.is_empty() { + tracing::warn!( + count = audio_attachments.len(), + "audio attachments received but no STT provider configured, dropping" + ); + } + text_base + }; + + let mut image_parts = Vec::new(); + for attachment in image_attachments { + if attachment.data.len() > MAX_IMAGE_BYTES { + tracing::warn!( + size = attachment.data.len(), + max = MAX_IMAGE_BYTES, + "image attachment exceeds size limit, skipping" + ); + continue; + } + let mime_type = detect_image_mime(attachment.filename.as_deref()).to_string(); + 
image_parts.push(MessagePart::Image(Box::new(ImageData { + data: attachment.data, + mime_type, + }))); + } + + (text, image_parts) + } + async fn process_user_message( &mut self, text: String, @@ -433,6 +523,235 @@ impl Agent { Ok(()) } + + async fn handle_image_command( + &mut self, + path: &str, + extra_parts: &mut Vec, + ) -> Result<(), error::AgentError> { + use std::path::Component; + use zeph_llm::provider::{ImageData, MessagePart}; + + // Reject paths that traverse outside the current directory. + let has_parent_dir = std::path::Path::new(path) + .components() + .any(|c| c == Component::ParentDir); + if has_parent_dir { + self.channel + .send("Invalid image path: path traversal not allowed") + .await?; + return Ok(()); + } + + let data = match std::fs::read(path) { + Ok(d) => d, + Err(e) => { + self.channel + .send(&format!("Cannot read image {path}: {e}")) + .await?; + return Ok(()); + } + }; + if data.len() > MAX_IMAGE_BYTES { + self.channel + .send(&format!( + "Image {path} exceeds size limit ({} MB), skipping", + MAX_IMAGE_BYTES / 1024 / 1024 + )) + .await?; + return Ok(()); + } + let mime_type = detect_image_mime(Some(path)).to_string(); + extra_parts.push(MessagePart::Image(Box::new(ImageData { data, mime_type }))); + self.channel + .send(&format!("Image loaded: {path}. 
Send your message.")) + .await?; + Ok(()) + } + + async fn handle_skills_command(&mut self) -> Result<(), error::AgentError> { + use std::fmt::Write; + + let mut output = String::from("Available skills:\n\n"); + + for meta in self.skill_state.registry.all_meta() { + let trust_info = if let Some(memory) = &self.memory_state.memory { + memory + .sqlite() + .load_skill_trust(&meta.name) + .await + .ok() + .flatten() + .map_or_else(String::new, |r| format!(" [{}]", r.trust_level)) + } else { + String::new() + }; + let _ = writeln!(output, "- {} — {}{trust_info}", meta.name, meta.description); + } + + if let Some(memory) = &self.memory_state.memory { + match memory.sqlite().load_skill_usage().await { + Ok(usage) if !usage.is_empty() => { + output.push_str("\nUsage statistics:\n\n"); + for row in &usage { + let _ = writeln!( + output, + "- {}: {} invocations (last: {})", + row.skill_name, row.invocation_count, row.last_used_at, + ); + } + } + Ok(_) => {} + Err(e) => tracing::warn!("failed to load skill usage: {e:#}"), + } + } + + self.channel.send(&output).await?; + Ok(()) + } + + async fn handle_feedback(&mut self, input: &str) -> Result<(), error::AgentError> { + let Some((name, rest)) = input.split_once(' ') else { + self.channel + .send("Usage: /feedback ") + .await?; + return Ok(()); + }; + let (skill_name, feedback) = (name.trim(), rest.trim().trim_matches('"')); + + if feedback.is_empty() { + self.channel + .send("Usage: /feedback ") + .await?; + return Ok(()); + } + + let Some(memory) = &self.memory_state.memory else { + self.channel.send("Memory not available.").await?; + return Ok(()); + }; + + memory + .sqlite() + .record_skill_outcome( + skill_name, + None, + self.memory_state.conversation_id, + "user_rejection", + Some(feedback), + ) + .await?; + + if self.is_learning_enabled() { + self.generate_improved_skill(skill_name, feedback, "", Some(feedback)) + .await + .ok(); + } + + self.channel + .send(&format!("Feedback recorded for \"{skill_name}\".")) + 
.await?; + Ok(()) + } + + async fn reload_skills(&mut self) { + let new_registry = SkillRegistry::load(&self.skill_state.skill_paths); + if new_registry.fingerprint() == self.skill_state.registry.fingerprint() { + return; + } + self.skill_state.registry = new_registry; + + let all_meta = self.skill_state.registry.all_meta(); + let provider = self.provider.clone(); + let embed_fn = |text: &str| -> zeph_skills::matcher::EmbedFuture { + let owned = text.to_owned(); + let p = provider.clone(); + Box::pin(async move { p.embed(&owned).await }) + }; + + let needs_inmemory_rebuild = !self + .skill_state + .matcher + .as_ref() + .is_some_and(SkillMatcherBackend::is_qdrant); + + if needs_inmemory_rebuild { + self.skill_state.matcher = SkillMatcher::new(&all_meta, embed_fn) + .await + .map(SkillMatcherBackend::InMemory); + } else if let Some(ref mut backend) = self.skill_state.matcher + && let Err(e) = backend + .sync(&all_meta, &self.skill_state.embedding_model, embed_fn) + .await + { + tracing::warn!("failed to sync skill embeddings: {e:#}"); + } + + let all_skills: Vec = self + .skill_state + .registry + .all_meta() + .iter() + .filter_map(|m| self.skill_state.registry.get_skill(&m.name).ok()) + .collect(); + let trust_map = self.build_skill_trust_map().await; + let skills_prompt = format_skills_prompt(&all_skills, std::env::consts::OS, &trust_map); + self.skill_state + .last_skills_prompt + .clone_from(&skills_prompt); + let system_prompt = build_system_prompt(&skills_prompt, None, None, false); + if let Some(msg) = self.messages.first_mut() { + msg.content = system_prompt; + } + + tracing::info!( + "reloaded {} skill(s)", + self.skill_state.registry.all_meta().len() + ); + } + + fn reload_config(&mut self) { + let Some(ref path) = self.config_path else { + return; + }; + let config = match Config::load(path) { + Ok(c) => c, + Err(e) => { + tracing::warn!("config reload failed: {e:#}"); + return; + } + }; + + self.runtime.security = config.security; + 
self.runtime.timeouts = config.timeouts; + self.memory_state.history_limit = config.memory.history_limit; + self.memory_state.recall_limit = config.memory.semantic.recall_limit; + self.memory_state.summarization_threshold = config.memory.summarization_threshold; + self.skill_state.max_active_skills = config.skills.max_active_skills; + self.skill_state.disambiguation_threshold = config.skills.disambiguation_threshold; + + if config.memory.context_budget_tokens > 0 { + self.context_state.budget = Some(ContextBudget::new( + config.memory.context_budget_tokens, + 0.20, + )); + } else { + self.context_state.budget = None; + } + self.context_state.compaction_threshold = config.memory.compaction_threshold; + self.context_state.compaction_preserve_tail = config.memory.compaction_preserve_tail; + self.context_state.prune_protect_tokens = config.memory.prune_protect_tokens; + self.memory_state.cross_session_score_threshold = + config.memory.cross_session_score_threshold; + + #[cfg(feature = "index")] + { + self.index.repo_map_ttl = + std::time::Duration::from_secs(config.index.repo_map_ttl_secs); + } + + tracing::info!("config reloaded"); + } } pub(crate) async fn shutdown_signal(rx: &mut watch::Receiver) { while !*rx.borrow_and_update() { @@ -458,13 +777,14 @@ pub(crate) async fn recv_optional(rx: &mut Option>) -> Opti #[cfg(test)] pub(super) mod agent_tests { + use super::message_queue::{MAX_AUDIO_BYTES, MAX_IMAGE_BYTES, detect_image_mime}; #[allow(unused_imports)] pub(crate) use super::{ Agent, CODE_CONTEXT_PREFIX, CROSS_SESSION_PREFIX, DOOM_LOOP_WINDOW, RECALL_PREFIX, SUMMARY_PREFIX, TOOL_OUTPUT_SUFFIX, format_tool_output, recv_optional, shutdown_signal, }; pub(crate) use crate::channel::Channel; - use crate::channel::ChannelMessage; + use crate::channel::{Attachment, AttachmentKind, ChannelMessage}; pub(crate) use crate::config::{SecurityConfig, TimeoutConfig}; pub(crate) use crate::metrics::MetricsSnapshot; use std::sync::{Arc, Mutex}; @@ -1462,4 +1782,289 @@ 
pub(super) mod agent_tests { tokio::time::sleep(std::time::Duration::from_millis(20)).await; assert!(token2.is_cancelled()); } + + mod resolve_message_tests { + use super::*; + use crate::channel::{Attachment, AttachmentKind, ChannelMessage}; + use std::future::Future; + use std::pin::Pin; + use zeph_llm::error::LlmError; + use zeph_llm::stt::{SpeechToText, Transcription}; + + struct MockStt { + text: Option, + } + + impl MockStt { + fn ok(text: &str) -> Self { + Self { + text: Some(text.to_string()), + } + } + + fn failing() -> Self { + Self { text: None } + } + } + + impl SpeechToText for MockStt { + fn transcribe( + &self, + _audio: &[u8], + _filename: Option<&str>, + ) -> Pin> + Send + '_>> + { + let result = match &self.text { + Some(t) => Ok(Transcription { + text: t.clone(), + language: None, + duration_secs: None, + }), + None => Err(LlmError::TranscriptionFailed("mock error".into())), + }; + Box::pin(async move { result }) + } + } + + fn make_agent(stt: Option>) -> Agent { + let provider = mock_provider(vec!["ok".into()]); + let empty: Vec = vec![]; + let registry = zeph_skills::registry::SkillRegistry::load(&empty); + let channel = MockChannel::new(vec![]); + let executor = MockToolExecutor::no_tools(); + let mut agent = Agent::new(provider, channel, registry, None, 5, executor); + agent.stt = stt; + agent + } + + fn audio_attachment(data: &[u8]) -> Attachment { + Attachment { + kind: AttachmentKind::Audio, + data: data.to_vec(), + filename: Some("test.wav".into()), + } + } + + #[tokio::test] + async fn no_audio_attachments_returns_text() { + let agent = make_agent(None); + let msg = ChannelMessage { + text: "hello".into(), + attachments: vec![], + }; + assert_eq!(agent.resolve_message(msg).await.0, "hello"); + } + + #[tokio::test] + async fn audio_without_stt_returns_original_text() { + let agent = make_agent(None); + let msg = ChannelMessage { + text: "hello".into(), + attachments: vec![audio_attachment(b"audio-data")], + }; + 
assert_eq!(agent.resolve_message(msg).await.0, "hello"); + } + + #[tokio::test] + async fn audio_with_stt_prepends_transcription() { + let agent = make_agent(Some(Box::new(MockStt::ok("transcribed text")))); + let msg = ChannelMessage { + text: "original".into(), + attachments: vec![audio_attachment(b"audio-data")], + }; + let (result, _) = agent.resolve_message(msg).await; + assert!(result.contains("[transcribed audio]")); + assert!(result.contains("transcribed text")); + assert!(result.contains("original")); + } + + #[tokio::test] + async fn audio_with_stt_no_original_text() { + let agent = make_agent(Some(Box::new(MockStt::ok("transcribed text")))); + let msg = ChannelMessage { + text: String::new(), + attachments: vec![audio_attachment(b"audio-data")], + }; + let (result, _) = agent.resolve_message(msg).await; + assert_eq!(result, "transcribed text"); + } + + #[tokio::test] + async fn all_transcriptions_fail_returns_original() { + let agent = make_agent(Some(Box::new(MockStt::failing()))); + let msg = ChannelMessage { + text: "original".into(), + attachments: vec![audio_attachment(b"audio-data")], + }; + assert_eq!(agent.resolve_message(msg).await.0, "original"); + } + + #[tokio::test] + async fn multiple_audio_attachments_joined() { + let agent = make_agent(Some(Box::new(MockStt::ok("chunk")))); + let msg = ChannelMessage { + text: String::new(), + attachments: vec![ + audio_attachment(b"a1"), + audio_attachment(b"a2"), + audio_attachment(b"a3"), + ], + }; + let (result, _) = agent.resolve_message(msg).await; + assert_eq!(result, "chunk\nchunk\nchunk"); + } + + #[tokio::test] + async fn oversized_audio_skipped() { + let agent = make_agent(Some(Box::new(MockStt::ok("should not appear")))); + let big = vec![0u8; MAX_AUDIO_BYTES + 1]; + let msg = ChannelMessage { + text: "original".into(), + attachments: vec![Attachment { + kind: AttachmentKind::Audio, + data: big, + filename: None, + }], + }; + assert_eq!(agent.resolve_message(msg).await.0, "original"); + } + } 
+ + #[test] + fn detect_image_mime_jpeg() { + assert_eq!(detect_image_mime(Some("photo.jpg")), "image/jpeg"); + assert_eq!(detect_image_mime(Some("photo.jpeg")), "image/jpeg"); + } + + #[test] + fn detect_image_mime_gif() { + assert_eq!(detect_image_mime(Some("anim.gif")), "image/gif"); + } + + #[test] + fn detect_image_mime_webp() { + assert_eq!(detect_image_mime(Some("img.webp")), "image/webp"); + } + + #[test] + fn detect_image_mime_unknown_defaults_png() { + assert_eq!(detect_image_mime(Some("file.bmp")), "image/png"); + assert_eq!(detect_image_mime(None), "image/png"); + } + + #[tokio::test] + async fn resolve_message_extracts_image_attachment() { + let provider = mock_provider(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + let agent = Agent::new(provider, channel, registry, None, 5, executor); + + let msg = ChannelMessage { + text: "look at this".into(), + attachments: vec![Attachment { + kind: AttachmentKind::Image, + data: vec![0u8; 16], + filename: Some("test.jpg".into()), + }], + }; + let (text, parts) = agent.resolve_message(msg).await; + assert_eq!(text, "look at this"); + assert_eq!(parts.len(), 1); + match &parts[0] { + zeph_llm::provider::MessagePart::Image(img) => { + assert_eq!(img.mime_type, "image/jpeg"); + assert_eq!(img.data.len(), 16); + } + _ => panic!("expected Image part"), + } + } + + #[tokio::test] + async fn resolve_message_drops_oversized_image() { + let provider = mock_provider(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + let agent = Agent::new(provider, channel, registry, None, 5, executor); + + let msg = ChannelMessage { + text: "big image".into(), + attachments: vec![Attachment { + kind: AttachmentKind::Image, + data: vec![0u8; MAX_IMAGE_BYTES + 1], + filename: Some("huge.png".into()), + }], + }; + let (text, parts) = 
agent.resolve_message(msg).await; + assert_eq!(text, "big image"); + assert!(parts.is_empty()); + } + + #[tokio::test] + async fn handle_image_command_rejects_path_traversal() { + let provider = mock_provider(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + let mut agent = Agent::new(provider, channel, registry, None, 5, executor); + + let mut parts = Vec::new(); + let result = agent + .handle_image_command("../../etc/passwd", &mut parts) + .await; + assert!(result.is_ok()); + assert!(parts.is_empty()); + // Channel should have received an error message + let sent = agent.channel.sent_messages(); + assert!(sent.iter().any(|m| m.contains("traversal"))); + } + + #[tokio::test] + async fn handle_image_command_missing_file_sends_error() { + let provider = mock_provider(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + let mut agent = Agent::new(provider, channel, registry, None, 5, executor); + + let mut parts = Vec::new(); + let result = agent + .handle_image_command("/nonexistent/image.png", &mut parts) + .await; + assert!(result.is_ok()); + assert!(parts.is_empty()); + let sent = agent.channel.sent_messages(); + assert!(sent.iter().any(|m| m.contains("Cannot read image"))); + } + + #[tokio::test] + async fn handle_image_command_loads_valid_file() { + use std::io::Write; + let provider = mock_provider(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + let mut agent = Agent::new(provider, channel, registry, None, 5, executor); + + // Write a small temp image + let mut tmp = tempfile::NamedTempFile::with_suffix(".jpg").unwrap(); + let data = vec![0xFFu8, 0xD8, 0xFF, 0xE0]; + tmp.write_all(&data).unwrap(); + let path = tmp.path().to_str().unwrap().to_owned(); + + let mut parts = Vec::new(); + let 
result = agent.handle_image_command(&path, &mut parts).await; + assert!(result.is_ok()); + assert_eq!(parts.len(), 1); + match &parts[0] { + zeph_llm::provider::MessagePart::Image(img) => { + assert_eq!(img.data, data); + assert_eq!(img.mime_type, "image/jpeg"); + } + _ => panic!("expected Image part"), + } + let sent = agent.channel.sent_messages(); + assert!(sent.iter().any(|m| m.contains("Image loaded"))); + } } diff --git a/crates/zeph-core/src/agent/persistence.rs b/crates/zeph-core/src/agent/persistence.rs index 678f016..f293b6e 100644 --- a/crates/zeph-core/src/agent/persistence.rs +++ b/crates/zeph-core/src/agent/persistence.rs @@ -1,12 +1,10 @@ +use crate::channel::Channel; use zeph_llm::provider::Role; use zeph_memory::sqlite::role_str; -use zeph_tools::executor::ToolExecutor; - -use crate::channel::Channel; use super::Agent; -impl Agent { +impl Agent { /// Load conversation history from memory and inject into messages. /// /// # Errors diff --git a/crates/zeph-core/src/agent/tool_execution.rs b/crates/zeph-core/src/agent/tool_execution.rs index 7895819..067f4a0 100644 --- a/crates/zeph-core/src/agent/tool_execution.rs +++ b/crates/zeph-core/src/agent/tool_execution.rs @@ -1,6 +1,6 @@ use tokio_stream::StreamExt; use zeph_llm::provider::{ChatResponse, LlmProvider, Message, MessagePart, Role, ToolDefinition}; -use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput}; +use zeph_tools::executor::{ToolCall, ToolError, ToolOutput}; use super::{Agent, DOOM_LOOP_WINDOW, TOOL_LOOP_KEEP_RECENT, format_tool_output}; use crate::channel::Channel; @@ -68,8 +68,8 @@ fn handle_tool_use(out: &mut String, rest: &mut &str, start: usize) { } } -impl Agent { - pub(super) async fn process_response(&mut self) -> Result<(), super::error::AgentError> { +impl Agent { + pub(crate) async fn process_response(&mut self) -> Result<(), super::error::AgentError> { if self.provider.supports_tool_use() { tracing::debug!( provider = self.provider.name(), @@ -141,7 +141,7 
@@ impl Agent { let result = self .tool_executor - .execute(&response) + .execute_erased(&response) .instrument(tracing::info_span!("tool_exec")) .await; if !self.handle_tool_result(&response, result).await? { @@ -427,7 +427,9 @@ impl Agent { Err(ToolError::ConfirmationRequired { command }) => { let prompt = format!("Allow command: {command}?"); if self.channel.confirm(&prompt).await? { - if let Ok(Some(out)) = self.tool_executor.execute_confirmed(response).await { + if let Ok(Some(out)) = + self.tool_executor.execute_confirmed_erased(response).await + { let processed = self.maybe_summarize_tool_output(&out.summary).await; let formatted = format_tool_output(&out.tool_name, &processed); let display = self.maybe_redact(&formatted); @@ -533,7 +535,7 @@ impl Agent { let tool_defs: Vec = self .tool_executor - .tool_definitions() + .tool_definitions_erased() .iter() .map(tool_def_to_definition) .collect(); @@ -741,7 +743,7 @@ impl Agent { .iter() .zip(tool_calls.iter()) .map(|(call, tc)| { - self.tool_executor.execute_tool_call(call).instrument( + self.tool_executor.execute_tool_call_erased(call).instrument( tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id), ) }) @@ -751,7 +753,7 @@ impl Agent { use futures::StreamExt; let stream = futures::stream::iter(calls.iter().zip(tool_calls.iter()).map(|(call, tc)| { - self.tool_executor.execute_tool_call(call).instrument( + self.tool_executor.execute_tool_call_erased(call).instrument( tracing::info_span!("tool_exec", tool_name = %tc.name, idx = %tc.id), ) })); diff --git a/crates/zeph-core/src/agent/trust_commands.rs b/crates/zeph-core/src/agent/trust_commands.rs index 3f0d3ae..e53915b 100644 --- a/crates/zeph-core/src/agent/trust_commands.rs +++ b/crates/zeph-core/src/agent/trust_commands.rs @@ -3,9 +3,9 @@ use std::fmt::Write; use zeph_skills::TrustLevel; -use super::{Agent, Channel, ToolExecutor}; +use super::{Agent, Channel}; -impl Agent { +impl Agent { /// Handle `/skill trust [name [level]]`. 
pub(super) async fn handle_skill_trust_command( &mut self, diff --git a/crates/zeph-core/src/agent/utils.rs b/crates/zeph-core/src/agent/utils.rs index b57e5ed..b5d5a3f 100644 --- a/crates/zeph-core/src/agent/utils.rs +++ b/crates/zeph-core/src/agent/utils.rs @@ -1,12 +1,10 @@ use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role}; +use super::{Agent, CODE_CONTEXT_PREFIX}; use crate::channel::Channel; use crate::metrics::MetricsSnapshot; -use zeph_tools::executor::ToolExecutor; - -use super::{Agent, CODE_CONTEXT_PREFIX}; -impl Agent { +impl Agent { pub(super) fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) { if let Some(ref tx) = self.metrics_tx { let elapsed = self.start_time.elapsed().as_secs(); diff --git a/crates/zeph-llm/src/claude.rs b/crates/zeph-llm/src/claude.rs index 93889e2..2f946ff 100644 --- a/crates/zeph-llm/src/claude.rs +++ b/crates/zeph-llm/src/claude.rs @@ -90,11 +90,9 @@ impl ClaudeProvider { } fn has_image_parts(messages: &[Message]) -> bool { - messages.iter().any(|m| { - m.parts - .iter() - .any(|p| matches!(p, MessagePart::Image { .. 
})) - }) + messages + .iter() + .any(|m| m.parts.iter().any(|p| matches!(p, MessagePart::Image(_)))) } fn build_request(&self, messages: &[Message], stream: bool) -> reqwest::RequestBuilder { @@ -644,7 +642,7 @@ fn split_messages_structured(messages: &[Message]) -> (Option, Vec (Option, Vec { + MessagePart::Image(img) => { blocks.push(AnthropicContentBlock::Image { source: ImageSource { source_type: "base64".to_owned(), - media_type: mime_type.clone(), - data: STANDARD.encode(data), + media_type: img.mime_type.clone(), + data: STANDARD.encode(&img.data), }, }); } @@ -787,6 +785,7 @@ struct ContentBlock { #[cfg(test)] mod tests { use super::*; + use crate::provider::ImageData; use tokio_stream::StreamExt; #[test] @@ -1444,10 +1443,10 @@ mod tests { MessagePart::Text { text: "look at this".into(), }, - MessagePart::Image { + MessagePart::Image(Box::new(ImageData { data: data.clone(), mime_type: "image/jpeg".into(), - }, + })), ], ); let (system, chat) = split_messages_structured(&[msg]); @@ -1478,10 +1477,10 @@ mod tests { fn has_image_parts_detects_image_in_messages() { let with_image = Message::from_parts( Role::User, - vec![MessagePart::Image { + vec![MessagePart::Image(Box::new(ImageData { data: vec![1], mime_type: "image/png".into(), - }], + }))], ); let without_image = Message::from_legacy(Role::User, "plain text"); assert!(ClaudeProvider::has_image_parts(&[with_image])); diff --git a/crates/zeph-llm/src/ollama.rs b/crates/zeph-llm/src/ollama.rs index d44d795..eff824d 100644 --- a/crates/zeph-llm/src/ollama.rs +++ b/crates/zeph-llm/src/ollama.rs @@ -116,11 +116,9 @@ impl LlmProvider for OllamaProvider { } async fn chat(&self, messages: &[Message]) -> Result { - let has_images = messages.iter().any(|m| { - m.parts - .iter() - .any(|p| matches!(p, MessagePart::Image { .. 
})) - }); + let has_images = messages + .iter() + .any(|m| m.parts.iter().any(|p| matches!(p, MessagePart::Image(_)))); let model = if has_images { self.vision_model.as_deref().unwrap_or(&self.model) } else { @@ -140,11 +138,9 @@ impl LlmProvider for OllamaProvider { } async fn chat_stream(&self, messages: &[Message]) -> Result { - let has_images = messages.iter().any(|m| { - m.parts - .iter() - .any(|p| matches!(p, MessagePart::Image { .. })) - }); + let has_images = messages + .iter() + .any(|m| m.parts.iter().any(|p| matches!(p, MessagePart::Image(_)))); let model = if has_images { self.vision_model.as_deref().unwrap_or(&self.model) } else { @@ -207,9 +203,7 @@ fn convert_message(msg: &Message) -> ChatMessage { .parts .iter() .filter_map(|p| match p { - MessagePart::Image { data, .. } => { - Some(OllamaImage::from_base64(STANDARD.encode(data))) - } + MessagePart::Image(img) => Some(OllamaImage::from_base64(STANDARD.encode(&img.data))), _ => None, }) .collect(); @@ -256,6 +250,7 @@ fn parse_host_port(url: &str) -> (String, u16) { #[cfg(test)] mod tests { use super::*; + use crate::provider::ImageData; #[test] fn context_window_none_by_default() { @@ -656,10 +651,10 @@ mod tests { MessagePart::Text { text: "describe".into(), }, - MessagePart::Image { + MessagePart::Image(Box::new(ImageData { data: data.clone(), mime_type: "image/jpeg".into(), - }, + })), ], ); let chat_msg = convert_message(&msg); diff --git a/crates/zeph-llm/src/openai.rs b/crates/zeph-llm/src/openai.rs index 436226c..4b3e42f 100644 --- a/crates/zeph-llm/src/openai.rs +++ b/crates/zeph-llm/src/openai.rs @@ -523,11 +523,9 @@ struct VisionChatRequest<'a> { } fn has_image_parts(messages: &[Message]) -> bool { - messages.iter().any(|m| { - m.parts - .iter() - .any(|p| matches!(p, MessagePart::Image { .. 
})) - }) + messages + .iter() + .any(|m| m.parts.iter().any(|p| matches!(p, MessagePart::Image(_)))) } fn convert_messages_vision(messages: &[Message]) -> Vec { @@ -539,10 +537,7 @@ fn convert_messages_vision(messages: &[Message]) -> Vec { Role::User => "user", Role::Assistant => "assistant", }; - let has_images = msg - .parts - .iter() - .any(|p| matches!(p, MessagePart::Image { .. })); + let has_images = msg.parts.iter().any(|p| matches!(p, MessagePart::Image(_))); if has_images { let mut parts = Vec::new(); let text_str: String = msg @@ -562,11 +557,11 @@ fn convert_messages_vision(messages: &[Message]) -> Vec { parts.push(OpenAiContentPart::Text { text: text_str }); } for part in &msg.parts { - if let MessagePart::Image { data, mime_type } = part { - let b64 = STANDARD.encode(data); + if let MessagePart::Image(img) = part { + let b64 = STANDARD.encode(&img.data); parts.push(OpenAiContentPart::ImageUrl { image_url: ImageUrlDetail { - url: format!("data:{mime_type};base64,{b64}"), + url: format!("data:{};base64,{b64}", img.mime_type), }, }); } @@ -878,6 +873,7 @@ struct JsonSchemaFormat<'a> { #[cfg(test)] mod tests { use super::*; + use crate::provider::ImageData; use tokio_stream::StreamExt; fn test_provider() -> OpenAiProvider { @@ -1506,10 +1502,10 @@ mod tests { MessagePart::Text { text: "look".into(), }, - MessagePart::Image { + MessagePart::Image(Box::new(ImageData { data: vec![1, 2, 3], mime_type: "image/png".into(), - }, + })), ], ); let msg_text_only = Message::from_legacy(Role::User, "plain"); @@ -1527,10 +1523,10 @@ mod tests { MessagePart::Text { text: "describe this".into(), }, - MessagePart::Image { + MessagePart::Image(Box::new(ImageData { data: data.clone(), mime_type: "image/jpeg".into(), - }, + })), ], ); let converted = convert_messages_vision(&[msg]); @@ -1569,10 +1565,10 @@ mod tests { fn convert_messages_vision_image_only_no_text_part() { let msg = Message::from_parts( Role::User, - vec![MessagePart::Image { + 
vec![MessagePart::Image(Box::new(ImageData { data: vec![1], mime_type: "image/png".into(), - }], + }))], ); let converted = convert_messages_vision(&[msg]); // No text parts collected → only image_url diff --git a/crates/zeph-llm/src/provider.rs b/crates/zeph-llm/src/provider.rs index 3defb66..a461be7 100644 --- a/crates/zeph-llm/src/provider.rs +++ b/crates/zeph-llm/src/provider.rs @@ -123,11 +123,14 @@ pub enum MessagePart { #[serde(default)] is_error: bool, }, - Image { - #[serde(with = "serde_bytes_base64")] - data: Vec, - mime_type: String, - }, + Image(Box), +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ImageData { + #[serde(with = "serde_bytes_base64")] + pub data: Vec, + pub mime_type: String, } mod serde_bytes_base64 { @@ -221,8 +224,8 @@ impl Message { } => { let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}"); } - MessagePart::Image { data, mime_type } => { - let _ = write!(out, "[image: {mime_type}, {} bytes]", data.len()); + MessagePart::Image(img) => { + let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len()); } } } @@ -1017,16 +1020,16 @@ mod tests { #[test] fn message_part_image_roundtrip_json() { - let part = MessagePart::Image { + let part = MessagePart::Image(Box::new(ImageData { data: vec![1, 2, 3, 4], mime_type: "image/jpeg".into(), - }; + })); let json = serde_json::to_string(&part).unwrap(); let decoded: MessagePart = serde_json::from_str(&json).unwrap(); match decoded { - MessagePart::Image { data, mime_type } => { - assert_eq!(data, vec![1, 2, 3, 4]); - assert_eq!(mime_type, "image/jpeg"); + MessagePart::Image(img) => { + assert_eq!(img.data, vec![1, 2, 3, 4]); + assert_eq!(img.mime_type, "image/jpeg"); } _ => panic!("expected Image variant"), } @@ -1040,10 +1043,10 @@ mod tests { MessagePart::Text { text: "see this".into(), }, - MessagePart::Image { + MessagePart::Image(Box::new(ImageData { data: vec![0u8; 100], mime_type: "image/png".into(), - }, + })), ], ); let content = 
msg.to_llm_content(); diff --git a/crates/zeph-tools/src/executor.rs b/crates/zeph-tools/src/executor.rs index 47c46e6..90655be 100644 --- a/crates/zeph-tools/src/executor.rs +++ b/crates/zeph-tools/src/executor.rs @@ -199,6 +199,59 @@ pub trait ToolExecutor: Send + Sync { } } +/// Object-safe erased version of [`ToolExecutor`] using boxed futures. +/// +/// Implemented automatically for all `T: ToolExecutor + 'static`. +/// Use `Box` when dynamic dispatch is required. +pub trait ErasedToolExecutor: Send + Sync { + fn execute_erased<'a>( + &'a self, + response: &'a str, + ) -> std::pin::Pin, ToolError>> + Send + 'a>>; + + fn execute_confirmed_erased<'a>( + &'a self, + response: &'a str, + ) -> std::pin::Pin, ToolError>> + Send + 'a>>; + + fn tool_definitions_erased(&self) -> Vec; + + fn execute_tool_call_erased<'a>( + &'a self, + call: &'a ToolCall, + ) -> std::pin::Pin, ToolError>> + Send + 'a>>; +} + +impl ErasedToolExecutor for T { + fn execute_erased<'a>( + &'a self, + response: &'a str, + ) -> std::pin::Pin, ToolError>> + Send + 'a>> + { + Box::pin(self.execute(response)) + } + + fn execute_confirmed_erased<'a>( + &'a self, + response: &'a str, + ) -> std::pin::Pin, ToolError>> + Send + 'a>> + { + Box::pin(self.execute_confirmed(response)) + } + + fn tool_definitions_erased(&self) -> Vec { + self.tool_definitions() + } + + fn execute_tool_call_erased<'a>( + &'a self, + call: &'a ToolCall, + ) -> std::pin::Pin, ToolError>> + Send + 'a>> + { + Box::pin(self.execute_tool_call(call)) + } +} + /// Extract fenced code blocks with the given language marker from text. /// /// Searches for `` ```{lang} `` … `` ``` `` pairs, returning trimmed content. 
diff --git a/crates/zeph-tools/src/lib.rs b/crates/zeph-tools/src/lib.rs index eba6339..3ee17a2 100644 --- a/crates/zeph-tools/src/lib.rs +++ b/crates/zeph-tools/src/lib.rs @@ -19,8 +19,8 @@ pub use audit::{AuditEntry, AuditLogger, AuditResult}; pub use composite::CompositeExecutor; pub use config::{AuditConfig, ScrapeConfig, ShellConfig, ToolsConfig}; pub use executor::{ - DiffData, FilterStats, MAX_TOOL_OUTPUT_CHARS, ToolCall, ToolError, ToolEvent, ToolEventTx, - ToolExecutor, ToolOutput, truncate_tool_output, + DiffData, ErasedToolExecutor, FilterStats, MAX_TOOL_OUTPUT_CHARS, ToolCall, ToolError, + ToolEvent, ToolEventTx, ToolExecutor, ToolOutput, truncate_tool_output, }; pub use file::FileExecutor; pub use filter::{ diff --git a/src/main.rs b/src/main.rs index 6126e8b..79e69f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1072,14 +1072,13 @@ async fn create_channel(config: &Config) -> anyhow::Result { } #[cfg(feature = "scheduler")] -async fn bootstrap_scheduler( - agent: zeph_core::agent::Agent, +async fn bootstrap_scheduler( + agent: zeph_core::agent::Agent, config: &Config, shutdown_rx: watch::Receiver, -) -> zeph_core::agent::Agent +) -> zeph_core::agent::Agent where C: zeph_core::channel::Channel, - T: zeph_tools::executor::ToolExecutor, { if !config.scheduler.enabled { if config.agent.auto_update_check { diff --git a/tests/performance_agent_integration.rs b/tests/performance_agent_integration.rs index 5dd734b..907197b 100644 --- a/tests/performance_agent_integration.rs +++ b/tests/performance_agent_integration.rs @@ -341,7 +341,7 @@ async fn integration_agent_tool_executor_types() { let executor = ShellExecutor::new(&shell_config); // Should compile and construct successfully - let _agent: Agent = Agent::new( + let _agent: Agent = Agent::new( provider, channel, SkillRegistry::default(),