diff --git a/CHANGELOG.md b/CHANGELOG.md index f665f53d..0f794256 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ## [Unreleased] ### Added +- Anthropic prompt caching with structured system content blocks (#337) +- Configurable summary provider for tool output summarization via local model (#338) +- Aggressive inline pruning of stale tool outputs in tool loops (#339) +- Cache usage metrics (cache_read_tokens, cache_creation_tokens) in MetricsSnapshot (#340) - Native tool_use support for Claude provider (Anthropic API format) (#256) - Native function calling support for OpenAI provider (#257) - `ToolDefinition`, `ChatResponse`, `ToolUseRequest` types in zeph-llm (#254) diff --git a/config/default.toml b/config/default.toml index 5f961583..d175bc2d 100644 --- a/config/default.toml +++ b/config/default.toml @@ -3,6 +3,9 @@ name = "Zeph" # Maximum tool execution iterations per user message (doom-loop protection) max_tool_iterations = 10 +# Optional local model for tool output summarization and context compaction. +# Format: "ollama/<model>". Falls back to primary provider if unset. 
+# summary_model = "ollama/llama3.2" [llm] # LLM provider: "ollama" for local models or "claude" for Claude API diff --git a/crates/zeph-core/src/agent/context.rs b/crates/zeph-core/src/agent/context.rs index bef1a675..57b97834 100644 --- a/crates/zeph-core/src/agent/context.rs +++ b/crates/zeph-core/src/agent/context.rs @@ -84,7 +84,7 @@ impl Agent Agent usize { + if self.messages.len() <= keep_recent + 1 { + return 0; + } + let boundary = self.messages.len().saturating_sub(keep_recent); + let mut freed = 0usize; + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + .cast_signed(); + // Skip system prompt (index 0), prune from 1..boundary + for msg in &mut self.messages[1..boundary] { + let mut modified = false; + for part in &mut msg.parts { + match part { + MessagePart::ToolOutput { + body, compacted_at, .. + } if compacted_at.is_none() && !body.is_empty() => { + freed += estimate_tokens(body); + *compacted_at = Some(now); + *body = String::new(); + modified = true; + } + MessagePart::ToolResult { content, .. } if estimate_tokens(content) > 20 => { + freed += estimate_tokens(content); + "[pruned]".clone_into(content); + freed -= 1; + modified = true; + } + _ => {} + } + } + if modified { + msg.rebuild_content(); + } + } + if freed > 0 { + self.update_metrics(|m| m.tool_output_prunes += 1); + tracing::debug!( + freed, + boundary, + keep_recent, + "inline pruned stale tool outputs" + ); + } + freed + } + /// Two-tier compaction: Tier 1 prunes tool outputs, Tier 2 falls back to full LLM compaction. 
#[allow( clippy::cast_precision_loss, @@ -674,6 +726,9 @@ impl Agent"); + system_prompt.push_str("\n"); + #[cfg(feature = "mcp")] self.append_mcp_prompt(query, &mut system_prompt).await; @@ -1505,4 +1560,110 @@ mod tests { assert!(800 >= threshold); // at threshold → should stop assert!(799 < threshold); // below threshold → should continue } + + #[test] + fn prune_stale_tool_outputs_clears_old() { + let provider = MockProvider::new(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + let (tx, rx) = watch::channel(crate::metrics::MetricsSnapshot::default()); + + let mut agent = Agent::new(provider, channel, registry, None, 5, executor) + .with_context_budget(10000, 0.20, 0.75, 4, 0) + .with_metrics(tx); + + // Add 6 messages with tool outputs + for i in 0..6 { + agent.messages.push(Message::from_parts( + Role::User, + vec![MessagePart::ToolOutput { + tool_name: format!("tool_{i}"), + body: "x".repeat(200), + compacted_at: None, + }], + )); + } + // 7 messages total (1 system + 6 user) + + let freed = agent.prune_stale_tool_outputs(4); + assert!(freed > 0); + assert_eq!(rx.borrow().tool_output_prunes, 1); + + // Messages 1..3 should be pruned (boundary = 7-4=3) + for i in 1..3 { + if let MessagePart::ToolOutput { + body, compacted_at, .. + } = &agent.messages[i].parts[0] + { + assert!(body.is_empty(), "message {i} should be pruned"); + assert!(compacted_at.is_some()); + } + } + // Messages 3..6 should be untouched + for i in 3..7 { + if let MessagePart::ToolOutput { body, .. 
} = &agent.messages[i].parts[0] { + assert!(!body.is_empty(), "message {i} should be kept"); + } + } + } + + #[test] + fn prune_stale_tool_outputs_noop_when_few_messages() { + let provider = MockProvider::new(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + + let mut agent = Agent::new(provider, channel, registry, None, 5, executor); + + agent.messages.push(Message::from_parts( + Role::User, + vec![MessagePart::ToolOutput { + tool_name: "bash".into(), + body: "output".into(), + compacted_at: None, + }], + )); + + let freed = agent.prune_stale_tool_outputs(4); + assert_eq!(freed, 0); + } + + #[test] + fn prune_stale_prunes_tool_result_too() { + let provider = MockProvider::new(vec![]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + + let mut agent = Agent::new(provider, channel, registry, None, 5, executor); + + // Add old message with large ToolResult + agent.messages.push(Message::from_parts( + Role::User, + vec![MessagePart::ToolResult { + tool_use_id: "t1".into(), + content: "x".repeat(500), + is_error: false, + }], + )); + // Add 4 recent messages + for _ in 0..4 { + agent.messages.push(Message { + role: Role::User, + content: "recent".into(), + parts: vec![], + }); + } + + let freed = agent.prune_stale_tool_outputs(4); + assert!(freed > 0); + + if let MessagePart::ToolResult { content, .. 
} = &agent.messages[1].parts[0] { + assert_eq!(content, "[pruned]"); + } else { + panic!("expected ToolResult"); + } + } } diff --git a/crates/zeph-core/src/agent/mod.rs b/crates/zeph-core/src/agent/mod.rs index e830ec8d..618596f5 100644 --- a/crates/zeph-core/src/agent/mod.rs +++ b/crates/zeph-core/src/agent/mod.rs @@ -33,6 +33,7 @@ use crate::config_watcher::ConfigEvent; use crate::context::{ContextBudget, EnvironmentContext, build_system_prompt}; const DOOM_LOOP_WINDOW: usize = 3; +const TOOL_LOOP_KEEP_RECENT: usize = 4; const MAX_QUEUE_SIZE: usize = 10; const MESSAGE_MERGE_WINDOW: Duration = Duration::from_millis(500); const RECALL_PREFIX: &str = "[semantic recall]\n"; @@ -120,6 +121,7 @@ pub struct Agent start_time: Instant, message_queue: VecDeque, summarize_tool_output_enabled: bool, + summary_provider: Option
<P>
, permission_policy: zeph_tools::PermissionPolicy, warmup_ready: Option>, max_tool_iterations: usize, @@ -209,6 +211,7 @@ impl Agent Agent Self { + self.summary_provider = Some(provider); + self + } + + fn summary_or_primary_provider(&self) -> &P { + self.summary_provider.as_ref().unwrap_or(&self.provider) + } + #[must_use] pub fn with_permission_policy(mut self, policy: zeph_tools::PermissionPolicy) -> Self { self.permission_policy = policy; @@ -417,6 +430,15 @@ impl Agent Agent { pub(crate) async fn process_response(&mut self) -> Result<(), super::error::AgentError> { @@ -84,6 +84,9 @@ impl Agent Agent Agent Agent format!("[tool output summary]\n```\n{summary}\n```"), Err(e) => { tracing::warn!( @@ -428,6 +433,9 @@ impl Agent Agent, } #[derive(Debug, Deserialize)] @@ -612,6 +614,7 @@ impl Config { agent: AgentConfig { name: "Zeph".into(), max_tool_iterations: 10, + summary_model: None, }, llm: LlmConfig { provider: "ollama".into(), diff --git a/crates/zeph-core/src/metrics.rs b/crates/zeph-core/src/metrics.rs index af75e200..95692fb9 100644 --- a/crates/zeph-core/src/metrics.rs +++ b/crates/zeph-core/src/metrics.rs @@ -23,6 +23,8 @@ pub struct MetricsSnapshot { pub summaries_count: u64, pub context_compactions: u64, pub tool_output_prunes: u64, + pub cache_read_tokens: u64, + pub cache_creation_tokens: u64, } pub struct MetricsCollector { diff --git a/crates/zeph-llm/src/any.rs b/crates/zeph-llm/src/any.rs index f01b3d73..a5610ff0 100644 --- a/crates/zeph-llm/src/any.rs +++ b/crates/zeph-llm/src/any.rs @@ -98,6 +98,10 @@ impl LlmProvider for AnyProvider { ) -> Result { delegate_provider!(self, |p| p.chat_with_tools(messages, tools).await) } + + fn last_cache_usage(&self) -> Option<(u64, u64)> { + delegate_provider!(self, |p| p.last_cache_usage()) + } } #[cfg(test)] diff --git a/crates/zeph-llm/src/claude.rs b/crates/zeph-llm/src/claude.rs index 3a2fcd28..1f848464 100644 --- a/crates/zeph-llm/src/claude.rs +++ b/crates/zeph-llm/src/claude.rs @@ -13,15 +13,21 @@ 
use crate::provider::{ const API_URL: &str = "https://api.anthropic.com/v1/messages"; const ANTHROPIC_VERSION: &str = "2023-06-01"; +const ANTHROPIC_BETA: &str = "prompt-caching-2024-07-31"; const MAX_RETRIES: u32 = 3; const BASE_BACKOFF_SECS: u64 = 1; +const CACHE_MARKER_STABLE: &str = ""; +const CACHE_MARKER_TOOLS: &str = ""; +const CACHE_MARKER_VOLATILE: &str = ""; + pub struct ClaudeProvider { client: reqwest::Client, api_key: String, model: String, max_tokens: u32, pub(crate) status_tx: Option, + last_cache: std::sync::Mutex>, } impl fmt::Debug for ClaudeProvider { @@ -32,6 +38,7 @@ impl fmt::Debug for ClaudeProvider { .field("model", &self.model) .field("max_tokens", &self.max_tokens) .field("status_tx", &self.status_tx.is_some()) + .field("last_cache", &self.last_cache.lock().ok()) .finish() } } @@ -44,6 +51,7 @@ impl Clone for ClaudeProvider { model: self.model.clone(), max_tokens: self.max_tokens, status_tx: self.status_tx.clone(), + last_cache: std::sync::Mutex::new(None), } } } @@ -57,6 +65,7 @@ impl ClaudeProvider { model, max_tokens, status_tx: None, + last_cache: std::sync::Mutex::new(None), } } @@ -66,6 +75,15 @@ impl ClaudeProvider { self } + fn store_cache_usage(&self, usage: &ApiUsage) { + if let Ok(mut guard) = self.last_cache.lock() { + *guard = Some(( + usage.cache_creation_input_tokens, + usage.cache_read_input_tokens, + )); + } + } + fn emit_status(&self, msg: impl Into) { if let Some(ref tx) = self.status_tx { let _ = tx.send(msg.into()); @@ -74,11 +92,12 @@ impl ClaudeProvider { fn build_request(&self, messages: &[Message], stream: bool) -> reqwest::RequestBuilder { let (system, chat_messages) = split_messages(messages); + let system_blocks = system.map(|s| split_system_into_blocks(&s)); let body = RequestBody { model: &self.model, max_tokens: self.max_tokens, - system: system.as_deref(), + system: system_blocks, messages: &chat_messages, stream, }; @@ -87,6 +106,7 @@ impl ClaudeProvider { .post(API_URL) .header("x-api-key", &self.api_key) 
.header("anthropic-version", ANTHROPIC_VERSION) + .header("anthropic-beta", ANTHROPIC_BETA) .header("content-type", "application/json") .json(&body) } @@ -129,6 +149,11 @@ impl ClaudeProvider { let resp: ApiResponse = serde_json::from_str(&text)?; + if let Some(ref usage) = resp.usage { + log_cache_usage(usage); + self.store_cache_usage(usage); + } + return resp .content .first() @@ -233,6 +258,10 @@ impl LlmProvider for ClaudeProvider { true } + fn last_cache_usage(&self) -> Option<(u64, u64)> { + self.last_cache.lock().ok().and_then(|g| *g) + } + async fn chat_with_tools( &self, messages: &[Message], @@ -248,10 +277,11 @@ impl LlmProvider for ClaudeProvider { }) .collect(); + let system_blocks = system.map(|s| split_system_into_blocks(&s)); let body = ToolRequestBody { model: &self.model, max_tokens: self.max_tokens, - system: system.as_deref(), + system: system_blocks, messages: &chat_messages, tools: &api_tools, }; @@ -262,6 +292,7 @@ impl LlmProvider for ClaudeProvider { .post(API_URL) .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) + .header("anthropic-beta", ANTHROPIC_BETA) .header("content-type", "application/json") .json(&body) .send() @@ -296,6 +327,10 @@ impl LlmProvider for ClaudeProvider { tracing::debug!(raw_response = %text, "Claude chat_with_tools response"); let resp: ToolApiResponse = serde_json::from_str(&text)?; + if let Some(ref usage) = resp.usage { + log_cache_usage(usage); + self.store_cache_usage(usage); + } let parsed = parse_tool_response(resp); tracing::debug!(?parsed, "parsed ChatResponse"); return Ok(parsed); @@ -315,6 +350,16 @@ fn retry_delay(response: &reqwest::Response, attempt: u32) -> Duration { Duration::from_secs(BASE_BACKOFF_SECS << attempt) } +fn log_cache_usage(usage: &ApiUsage) { + tracing::debug!( + input_tokens = usage.input_tokens, + output_tokens = usage.output_tokens, + cache_creation = usage.cache_creation_input_tokens, + cache_read = usage.cache_read_input_tokens, + "Claude API 
usage" + ); +} + fn parse_sse_event(data: &str, event_type: &str) -> Option> { match event_type { "content_block_delta" => match serde_json::from_str::(data) { @@ -379,6 +424,88 @@ fn split_messages(messages: &[Message]) -> (Option, Vec>) (system, chat) } +#[derive(Serialize, Clone, Debug)] +struct SystemContentBlock { + #[serde(rename = "type")] + block_type: &'static str, + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, +} + +#[derive(Serialize, Clone, Debug)] +struct CacheControl { + #[serde(rename = "type")] + cache_type: &'static str, +} + +fn split_system_into_blocks(system: &str) -> Vec { + // Split on volatile marker first: everything before is cacheable + let (cacheable_part, volatile_part) = if let Some(pos) = system.find(CACHE_MARKER_VOLATILE) { + ( + &system[..pos], + Some(&system[pos + CACHE_MARKER_VOLATILE.len()..]), + ) + } else { + (system, None) + }; + + let mut blocks = Vec::new(); + let cache_markers = [CACHE_MARKER_STABLE, CACHE_MARKER_TOOLS]; + let mut remaining = cacheable_part; + + for marker in &cache_markers { + if let Some(pos) = remaining.find(marker) { + let before = remaining[..pos].trim(); + if !before.is_empty() { + blocks.push(SystemContentBlock { + block_type: "text", + text: before.to_owned(), + cache_control: Some(CacheControl { + cache_type: "ephemeral", + }), + }); + } + remaining = &remaining[pos + marker.len()..]; + } + } + + let remaining = remaining.trim(); + if !remaining.is_empty() { + blocks.push(SystemContentBlock { + block_type: "text", + text: remaining.to_owned(), + cache_control: Some(CacheControl { + cache_type: "ephemeral", + }), + }); + } + + if let Some(volatile) = volatile_part { + let volatile = volatile.trim(); + if !volatile.is_empty() { + blocks.push(SystemContentBlock { + block_type: "text", + text: volatile.to_owned(), + cache_control: None, + }); + } + } + + // No markers at all: cache the entire prompt as one block + if blocks.is_empty() { + 
blocks.push(SystemContentBlock { + block_type: "text", + text: system.to_owned(), + cache_control: Some(CacheControl { + cache_type: "ephemeral", + }), + }); + } + + blocks +} + #[derive(Serialize)] struct AnthropicTool<'a> { name: &'a str, @@ -391,7 +518,7 @@ struct ToolRequestBody<'a> { model: &'a str, max_tokens: u32, #[serde(skip_serializing_if = "Option::is_none")] - system: Option<&'a str>, + system: Option>, messages: &'a [StructuredApiMessage], tools: &'a [AnthropicTool<'a>], } @@ -431,6 +558,8 @@ enum AnthropicContentBlock { #[derive(Deserialize)] struct ToolApiResponse { content: Vec, + #[serde(default)] + usage: Option, } fn parse_tool_response(resp: ToolApiResponse) -> ChatResponse { @@ -481,6 +610,7 @@ fn split_messages_structured(messages: &[Message]) -> (Option, Vec (Option, Vec { + MessagePart::ToolUse { id, name, input } if is_assistant => { blocks.push(AnthropicContentBlock::ToolUse { id: id.clone(), name: name.clone(), input: input.clone(), }); } + MessagePart::ToolUse { name, input, .. } => { + blocks.push(AnthropicContentBlock::Text { + text: format!("[tool_use: {name}] {input}"), + }); + } MessagePart::ToolResult { tool_use_id, content, is_error, - } => { + } if !is_assistant => { blocks.push(AnthropicContentBlock::ToolResult { tool_use_id: tool_use_id.clone(), content: content.clone(), is_error: *is_error, }); } + MessagePart::ToolResult { content, .. 
} => { + blocks.push(AnthropicContentBlock::Text { + text: content.clone(), + }); + } } } chat.push(StructuredApiMessage { @@ -548,7 +688,7 @@ struct RequestBody<'a> { model: &'a str, max_tokens: u32, #[serde(skip_serializing_if = "Option::is_none")] - system: Option<&'a str>, + system: Option>, messages: &'a [ApiMessage<'a>], #[serde(skip_serializing_if = "std::ops::Not::not")] stream: bool, @@ -563,6 +703,21 @@ struct ApiMessage<'a> { #[derive(Deserialize)] struct ApiResponse { content: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Deserialize, Debug)] +#[allow(clippy::struct_field_names)] +struct ApiUsage { + #[serde(default)] + input_tokens: u64, + #[serde(default)] + output_tokens: u64, + #[serde(default)] + cache_creation_input_tokens: u64, + #[serde(default)] + cache_read_input_tokens: u64, } #[derive(Deserialize)] @@ -820,16 +975,24 @@ mod tests { } #[test] - fn request_body_serializes_with_system() { + fn request_body_serializes_with_system_blocks() { let body = RequestBody { model: "claude-sonnet-4-5-20250929", max_tokens: 1024, - system: Some("You are helpful."), + system: Some(vec![SystemContentBlock { + block_type: "text", + text: "You are helpful.".into(), + cache_control: Some(CacheControl { + cache_type: "ephemeral", + }), + }]), messages: &[], stream: false, }; let json = serde_json::to_string(&body).unwrap(); - assert!(json.contains("\"system\":\"You are helpful.\"")); + assert!(json.contains("\"system\"")); + assert!(json.contains("You are helpful.")); + assert!(json.contains("\"cache_control\"")); } #[test] @@ -1092,6 +1255,60 @@ mod tests { assert!(!json.contains("stream")); } + #[test] + fn split_system_no_markers_caches_entire_block() { + let blocks = split_system_into_blocks("You are Zeph, an AI assistant."); + assert_eq!(blocks.len(), 1); + assert!(blocks[0].cache_control.is_some()); + assert!(blocks[0].text.contains("Zeph")); + } + + #[test] + fn split_system_with_all_markers() { + let system = format!( + "base 
prompt\n{CACHE_MARKER_STABLE}\nskills here\n\ + {CACHE_MARKER_TOOLS}\ntool catalog\n\ + {CACHE_MARKER_VOLATILE}\nvolatile stuff" + ); + let blocks = split_system_into_blocks(&system); + assert_eq!(blocks.len(), 4); + assert!(blocks[0].cache_control.is_some()); + assert!(blocks[0].text.contains("base prompt")); + assert!(blocks[1].cache_control.is_some()); + assert!(blocks[1].text.contains("skills here")); + assert!(blocks[2].cache_control.is_some()); + assert!(blocks[2].text.contains("tool catalog")); + assert!(blocks[3].cache_control.is_none()); + assert!(blocks[3].text.contains("volatile stuff")); + } + + #[test] + fn split_system_partial_markers() { + let system = format!("base prompt\n{CACHE_MARKER_VOLATILE}\nvolatile only"); + let blocks = split_system_into_blocks(&system); + assert_eq!(blocks.len(), 2); + assert!(blocks[0].cache_control.is_some()); + assert!(blocks[1].cache_control.is_none()); + } + + #[test] + fn api_usage_deserialization() { + let json = r#"{"input_tokens":100,"output_tokens":50,"cache_creation_input_tokens":1000,"cache_read_input_tokens":900}"#; + let usage: ApiUsage = serde_json::from_str(json).unwrap(); + assert_eq!(usage.input_tokens, 100); + assert_eq!(usage.output_tokens, 50); + assert_eq!(usage.cache_creation_input_tokens, 1000); + assert_eq!(usage.cache_read_input_tokens, 900); + } + + #[test] + fn api_response_with_usage() { + let json = r#"{"content":[{"text":"Hello"}],"usage":{"input_tokens":10,"output_tokens":5,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}"#; + let resp: ApiResponse = serde_json::from_str(json).unwrap(); + assert!(resp.usage.is_some()); + assert_eq!(resp.usage.unwrap().input_tokens, 10); + } + #[test] fn api_response_deserializes() { let json = r#"{"content":[{"text":"Hello world"}]}"#; @@ -1203,6 +1420,7 @@ mod tests { content: vec![AnthropicContentBlock::Text { text: "Hello".into(), }], + usage: None, }; let result = parse_tool_response(resp); assert!(matches!(result, ChatResponse::Text(s) if s 
== "Hello")); @@ -1221,6 +1439,7 @@ mod tests { input: serde_json::json!({"command": "ls"}), }, ], + usage: None, }; let result = parse_tool_response(resp); if let ChatResponse::ToolUse { text, tool_calls } = result { @@ -1241,6 +1460,7 @@ mod tests { name: "read".into(), input: serde_json::json!({"path": "/tmp/file.txt"}), }], + usage: None, }; let result = parse_tool_response(resp); if let ChatResponse::ToolUse { text, tool_calls } = result { diff --git a/crates/zeph-llm/src/orchestrator/mod.rs b/crates/zeph-llm/src/orchestrator/mod.rs index f6f70698..30a1e343 100644 --- a/crates/zeph-llm/src/orchestrator/mod.rs +++ b/crates/zeph-llm/src/orchestrator/mod.rs @@ -232,6 +232,12 @@ impl LlmProvider for ModelOrchestrator { provider.chat_with_tools(messages, tools).await } + fn last_cache_usage(&self) -> Option<(u64, u64)> { + self.providers + .get(&self.default_provider) + .and_then(LlmProvider::last_cache_usage) + } + fn name(&self) -> &'static str { "orchestrator" } diff --git a/crates/zeph-llm/src/orchestrator/router.rs b/crates/zeph-llm/src/orchestrator/router.rs index 64b83fad..9ea26f50 100644 --- a/crates/zeph-llm/src/orchestrator/router.rs +++ b/crates/zeph-llm/src/orchestrator/router.rs @@ -118,6 +118,17 @@ impl LlmProvider for SubProvider { } } + fn last_cache_usage(&self) -> Option<(u64, u64)> { + match self { + Self::Ollama(p) => p.last_cache_usage(), + Self::Claude(p) => p.last_cache_usage(), + #[cfg(feature = "openai")] + Self::OpenAi(p) => p.last_cache_usage(), + #[cfg(feature = "candle")] + Self::Candle(p) => p.last_cache_usage(), + } + } + fn name(&self) -> &'static str { match self { Self::Ollama(p) => p.name(), diff --git a/crates/zeph-llm/src/provider.rs b/crates/zeph-llm/src/provider.rs index b6d405d2..aef62b3c 100644 --- a/crates/zeph-llm/src/provider.rs +++ b/crates/zeph-llm/src/provider.rs @@ -233,6 +233,12 @@ pub trait LlmProvider: Send + Sync { ) -> Result { Ok(ChatResponse::Text(self.chat(messages).await?)) } + + /// Return the cache usage 
from the last API call, if available. + /// Returns `(cache_creation_tokens, cache_read_tokens)`. + fn last_cache_usage(&self) -> Option<(u64, u64)> { + None + } } #[cfg(test)] diff --git a/crates/zeph-tui/src/widgets/resources.rs b/crates/zeph-tui/src/widgets/resources.rs index 14e98dc4..162092e7 100644 --- a/crates/zeph-tui/src/widgets/resources.rs +++ b/crates/zeph-tui/src/widgets/resources.rs @@ -9,7 +9,7 @@ use crate::theme::Theme; pub fn render(metrics: &MetricsSnapshot, frame: &mut Frame, area: Rect) { let theme = Theme::default(); - let res_lines = vec![ + let mut res_lines = vec![ Line::from(format!(" Provider: {}", metrics.provider_name)), Line::from(format!(" Model: {}", metrics.model_name)), Line::from(format!(" Context: {}", metrics.context_tokens)), @@ -17,6 +17,16 @@ pub fn render(metrics: &MetricsSnapshot, frame: &mut Frame, area: Rect) { Line::from(format!(" API calls: {}", metrics.api_calls)), Line::from(format!(" Latency: {}ms", metrics.last_llm_latency_ms)), ]; + if metrics.cache_creation_tokens > 0 || metrics.cache_read_tokens > 0 { + res_lines.push(Line::from(format!( + " Cache write: {}", + metrics.cache_creation_tokens + ))); + res_lines.push(Line::from(format!( + " Cache read: {}", + metrics.cache_read_tokens + ))); + } let resources = Paragraph::new(res_lines).block( Block::default() .borders(Borders::ALL) diff --git a/src/main.rs b/src/main.rs index 37fb8bf2..eccdba07 100644 --- a/src/main.rs +++ b/src/main.rs @@ -363,6 +363,19 @@ async fn main() -> anyhow::Result<()> { let index_provider = provider.clone(); let warmup_provider_clone = provider.clone(); + let summary_provider = config.agent.summary_model.as_ref().and_then(|model_spec| { + match create_summary_provider(model_spec, &config) { + Ok(sp) => { + tracing::info!(model = %model_spec, "summary provider configured"); + Some(sp) + } + Err(e) => { + tracing::warn!("failed to create summary provider: {e:#}, using primary"); + None + } + } + }); + let agent = Agent::new( provider, 
channel, @@ -395,6 +408,12 @@ async fn main() -> anyhow::Result<()> { .with_permission_policy(permission_policy.clone()) .with_config_reload(config_path.clone(), config_reload_rx); + let agent = if let Some(sp) = summary_provider { + agent.with_summary_provider(sp) + } else { + agent + }; + #[cfg(feature = "index")] let mut _index_watcher: Option = None; #[cfg(feature = "index")] @@ -839,6 +858,16 @@ fn create_provider(config: &Config) -> anyhow::Result<AnyProvider> { } } +fn create_summary_provider(model_spec: &str, config: &Config) -> anyhow::Result<AnyProvider> { + if let Some(model) = model_spec.strip_prefix("ollama/") { + let base_url = &config.llm.base_url; + let provider = OllamaProvider::new(base_url, model.to_owned(), String::new()); + Ok(AnyProvider::Ollama(provider)) + } else { + bail!("unsupported summary_model format: {model_spec} (expected 'ollama/<model>')") + } +} + #[cfg(feature = "a2a")] fn spawn_a2a_server( config: &Config,