diff --git a/CHANGELOG.md b/CHANGELOG.md index 54410a76..ddbbb67c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - TUI test automation infrastructure: EventSource trait abstraction, insta widget snapshot tests, TestBackend integration tests, proptest layout verification, expectrl E2E terminal tests (#542) - CI snapshot regression pipeline with `cargo insta test --check` (#547) - Pipeline API with composable, type-safe `Step` trait, `Pipeline` builder, `ParallelStep` combinator, and built-in steps (`LlmStep`, `RetrievalStep`, `ExtractStep`, `MapStep`) (#466, #467, #468) +- Structured intent classification for skill disambiguation: when top-2 skill scores are within `disambiguation_threshold` (default 0.05), agent calls LLM via `chat_typed::()` to select the best-matching skill (#550) +- `ScoredMatch` struct exposing both skill index and cosine similarity score from matcher backends +- `IntentClassification` type (`skill_name`, `confidence`, `params`) with `JsonSchema` derive for schema-enforced LLM responses +- `disambiguation_threshold` in `[skills]` config section (default: 0.05) with `with_disambiguation_threshold()` builder on `Agent` ## [0.10.0] - 2026-02-18 diff --git a/Cargo.lock b/Cargo.lock index 8df07796..71c757e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8661,6 +8661,7 @@ dependencies = [ "notify", "notify-debouncer-mini", "qdrant-client", + "schemars 1.2.1", "serde", "serde_json", "tempfile", diff --git a/README.md b/README.md index 30bf8da7..f076d0a5 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,9 @@ This is the core idea behind Zeph. Every byte that enters the LLM context window Most frameworks inject all tool descriptions into every prompt. 50 tools installed? 50 descriptions in every request. 
-Zeph embeds skills and MCP tools as vectors at startup (concurrent embedding via `buffer_unordered`), then retrieves only the **top-K relevant** per query via cosine similarity. Install 500 skills — the prompt sees only the 5 that matter. [How skills work →](https://bug-ops.github.io/zeph/guide/skills.html) +Zeph embeds skills and MCP tools as vectors at startup (concurrent embedding via `buffer_unordered`), then retrieves only the **top-K relevant** per query via cosine similarity. Install 500 skills — the prompt sees only the 5 that matter. + +When two candidates score within a configurable threshold of each other, structured intent classification resolves the ambiguity: the agent calls the LLM with a typed `IntentClassification` schema and reorders candidates accordingly — no hallucination, no guessing. [How skills work →](https://bug-ops.github.io/zeph/guide/skills.html) ### Smart Output Filtering — 70-99% Token Savings diff --git a/crates/zeph-core/src/agent/context.rs b/crates/zeph-core/src/agent/context.rs index 40529968..e1811413 100644 --- a/crates/zeph-core/src/agent/context.rs +++ b/crates/zeph-core/src/agent/context.rs @@ -2,6 +2,8 @@ use std::fmt::Write; use zeph_llm::provider::MessagePart; use zeph_memory::semantic::estimate_tokens; +use zeph_skills::ScoredMatch; +use zeph_skills::loader::SkillMeta; use zeph_skills::prompt::format_skills_catalog; use super::{ @@ -628,12 +630,65 @@ impl Agent { Ok(()) } + async fn disambiguate_skills( + &self, + query: &str, + all_meta: &[&SkillMeta], + scored: &[ScoredMatch], + ) -> Option<Vec<usize>> { + let mut candidates = String::new(); + for sm in scored { + if let Some(meta) = all_meta.get(sm.index) { + let _ = writeln!( + candidates, + "- {} (score: {:.3}): {}", + meta.name, sm.score, meta.description + ); + } + } + + let prompt = format!( + "The user said: \"{query}\"\n\n\ + These skills matched with similar scores:\n{candidates}\n\ + Which skill best matches the user's intent? 
\ + Return the skill_name, your confidence (0-1), and any extracted parameters." + ); + + let messages = vec![Message::from_legacy(Role::User, prompt)]; + match self + .provider + .chat_typed::<IntentClassification>(&messages) + .await + { + Ok(classification) => { + tracing::info!( + skill = %classification.skill_name, + confidence = classification.confidence, + "disambiguation selected skill" + ); + let mut indices: Vec<usize> = scored.iter().map(|s| s.index).collect(); + if let Some(pos) = indices.iter().position(|&i| { + all_meta + .get(i) + .is_some_and(|m| m.name == classification.skill_name) + }) { + indices.swap(0, pos); + } + Some(indices) + } + Err(e) => { + tracing::warn!("disambiguation failed, using original order: {e:#}"); + None + } + } + } + #[allow(clippy::too_many_lines)] pub(super) async fn rebuild_system_prompt(&mut self, query: &str) { let all_meta = self.skill_state.registry.all_meta(); let matched_indices: Vec<usize> = if let Some(matcher) = &self.skill_state.matcher { let provider = self.provider.clone(); - matcher + let scored = matcher .match_skills( &all_meta, query, @@ -644,7 +699,18 @@ impl Agent { Box::pin(async move { p.embed(&owned).await }) }, ) - .await + .await; + + if scored.len() >= 2 + && (scored[0].score - scored[1].score) < self.skill_state.disambiguation_threshold + { + match self.disambiguate_skills(query, &all_meta, &scored).await { + Some(reordered) => reordered, + None => scored.iter().map(|s| s.index).collect(), + } + } else { + scored.iter().map(|s| s.index).collect() + } } else { (0..all_meta.len()).collect() }; @@ -1670,4 +1736,185 @@ mod tests { panic!("expected ToolResult"); } } + + #[tokio::test] + async fn disambiguate_skills_reorders_on_match() { + let json = r#"{"skill_name":"beta_skill","confidence":0.9,"params":{}}"#; + let provider = mock_provider(vec![json.to_string()]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + + let agent = Agent::new(provider, 
channel, registry, None, 5, executor); + + let metas = vec![ + SkillMeta { + name: "alpha_skill".into(), + description: "does alpha".into(), + compatibility: None, + license: None, + metadata: Vec::new(), + allowed_tools: Vec::new(), + skill_dir: std::path::PathBuf::new(), + }, + SkillMeta { + name: "beta_skill".into(), + description: "does beta".into(), + compatibility: None, + license: None, + metadata: Vec::new(), + allowed_tools: Vec::new(), + skill_dir: std::path::PathBuf::new(), + }, + ]; + let refs: Vec<&SkillMeta> = metas.iter().collect(); + let scored = vec![ + ScoredMatch { + index: 0, + score: 0.90, + }, + ScoredMatch { + index: 1, + score: 0.88, + }, + ]; + + let result = agent + .disambiguate_skills("do beta stuff", &refs, &scored) + .await; + assert!(result.is_some()); + let indices = result.unwrap(); + assert_eq!(indices[0], 1); // beta_skill moved to front + } + + #[tokio::test] + async fn disambiguate_skills_returns_none_on_error() { + let provider = mock_provider_failing(); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + + let agent = Agent::new(provider, channel, registry, None, 5, executor); + + let metas = vec![SkillMeta { + name: "test".into(), + description: "test".into(), + compatibility: None, + license: None, + metadata: Vec::new(), + allowed_tools: Vec::new(), + skill_dir: std::path::PathBuf::new(), + }]; + let refs: Vec<&SkillMeta> = metas.iter().collect(); + let scored = vec![ScoredMatch { + index: 0, + score: 0.5, + }]; + + let result = agent.disambiguate_skills("query", &refs, &scored).await; + assert!(result.is_none()); + } + + #[tokio::test] + async fn disambiguate_skills_empty_candidates() { + let json = r#"{"skill_name":"none","confidence":0.1,"params":{}}"#; + let provider = mock_provider(vec![json.to_string()]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + + 
let agent = Agent::new(provider, channel, registry, None, 5, executor); + + let metas: Vec<SkillMeta> = vec![]; + let refs: Vec<&SkillMeta> = metas.iter().collect(); + let scored: Vec<ScoredMatch> = vec![]; + + let result = agent.disambiguate_skills("query", &refs, &scored).await; + assert!(result.is_some()); + assert!(result.unwrap().is_empty()); + } + + #[tokio::test] + async fn disambiguate_skills_unknown_skill_preserves_order() { + let json = r#"{"skill_name":"nonexistent","confidence":0.5,"params":{}}"#; + let provider = mock_provider(vec![json.to_string()]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + + let agent = Agent::new(provider, channel, registry, None, 5, executor); + + let metas = vec![ + SkillMeta { + name: "first".into(), + description: "first skill".into(), + compatibility: None, + license: None, + metadata: Vec::new(), + allowed_tools: Vec::new(), + skill_dir: std::path::PathBuf::new(), + }, + SkillMeta { + name: "second".into(), + description: "second skill".into(), + compatibility: None, + license: None, + metadata: Vec::new(), + allowed_tools: Vec::new(), + skill_dir: std::path::PathBuf::new(), + }, + ]; + let refs: Vec<&SkillMeta> = metas.iter().collect(); + let scored = vec![ + ScoredMatch { + index: 0, + score: 0.9, + }, + ScoredMatch { + index: 1, + score: 0.88, + }, + ]; + + let result = agent + .disambiguate_skills("query", &refs, &scored) + .await + .unwrap(); + // No swap since LLM returned unknown name + assert_eq!(result[0], 0); + assert_eq!(result[1], 1); + } + + #[tokio::test] + async fn disambiguate_single_candidate_no_swap() { + let json = r#"{"skill_name":"only_skill","confidence":0.95,"params":{}}"#; + let provider = mock_provider(vec![json.to_string()]); + let channel = MockChannel::new(vec![]); + let registry = create_test_registry(); + let executor = MockToolExecutor::no_tools(); + + let agent = Agent::new(provider, channel, registry, None, 5, executor); + + let 
metas = vec![SkillMeta { + name: "only_skill".into(), + description: "the only one".into(), + compatibility: None, + license: None, + metadata: Vec::new(), + allowed_tools: Vec::new(), + skill_dir: std::path::PathBuf::new(), + }]; + let refs: Vec<&SkillMeta> = metas.iter().collect(); + let scored = vec![ScoredMatch { + index: 0, + score: 0.95, + }]; + + let result = agent + .disambiguate_skills("query", &refs, &scored) + .await + .unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0], 0); + } } diff --git a/crates/zeph-core/src/agent/mod.rs b/crates/zeph-core/src/agent/mod.rs index 57a64598..109ac60b 100644 --- a/crates/zeph-core/src/agent/mod.rs +++ b/crates/zeph-core/src/agent/mod.rs @@ -70,6 +70,7 @@ pub(super) struct SkillState { pub(super) skill_paths: Vec, pub(super) matcher: Option, pub(super) max_active_skills: usize, + pub(super) disambiguation_threshold: f32, pub(super) embedding_model: String, pub(super) skill_reload_rx: Option>, pub(super) active_skill_names: Vec, @@ -182,6 +183,7 @@ impl Agent { skill_paths: Vec::new(), matcher, max_active_skills, + disambiguation_threshold: 0.05, embedding_model: String::new(), skill_reload_rx: None, active_skill_names: Vec::new(), @@ -267,6 +269,12 @@ impl Agent { self } + #[must_use] + pub fn with_disambiguation_threshold(mut self, threshold: f32) -> Self { + self.skill_state.disambiguation_threshold = threshold; + self + } + #[must_use] pub fn with_shutdown(mut self, rx: watch::Receiver) -> Self { self.shutdown = rx; @@ -858,6 +866,7 @@ impl Agent { self.memory_state.recall_limit = config.memory.semantic.recall_limit; self.memory_state.summarization_threshold = config.memory.summarization_threshold; self.skill_state.max_active_skills = config.skills.max_active_skills; + self.skill_state.disambiguation_threshold = config.skills.disambiguation_threshold; if config.memory.context_budget_tokens > 0 { self.context_state.budget = Some(ContextBudget::new( diff --git a/crates/zeph-core/src/config/types.rs 
b/crates/zeph-core/src/config/types.rs index 9ccaec1e..d62508b0 100644 --- a/crates/zeph-core/src/config/types.rs +++ b/crates/zeph-core/src/config/types.rs @@ -262,12 +262,18 @@ pub struct SkillsConfig { pub paths: Vec, #[serde(default = "default_max_active_skills")] pub max_active_skills: usize, + #[serde(default = "default_disambiguation_threshold")] + pub disambiguation_threshold: f32, #[serde(default)] pub learning: LearningConfig, #[serde(default)] pub trust: TrustConfig, } +fn default_disambiguation_threshold() -> f32 { + 0.05 +} + #[derive(Debug, Clone, Deserialize)] pub struct TrustConfig { #[serde(default = "default_trust_default_level")] @@ -967,6 +973,7 @@ impl Config { skills: SkillsConfig { paths: vec!["./skills".into()], max_active_skills: default_max_active_skills(), + disambiguation_threshold: default_disambiguation_threshold(), learning: LearningConfig::default(), trust: TrustConfig::default(), }, diff --git a/crates/zeph-llm/Cargo.toml b/crates/zeph-llm/Cargo.toml index 16c7ce3f..838a0857 100644 --- a/crates/zeph-llm/Cargo.toml +++ b/crates/zeph-llm/Cargo.toml @@ -23,8 +23,8 @@ futures-core.workspace = true hf-hub = { workspace = true, optional = true } ollama-rs.workspace = true reqwest = { workspace = true, features = ["json", "rustls", "stream"] } -serde = { workspace = true, features = ["derive"] } schemars.workspace = true +serde = { workspace = true, features = ["derive"] } serde_json.workspace = true tokenizers = { workspace = true, optional = true } tokio = { workspace = true, features = ["rt", "sync", "time"] } diff --git a/crates/zeph-skills/Cargo.toml b/crates/zeph-skills/Cargo.toml index e0b9d317..099104fe 100644 --- a/crates/zeph-skills/Cargo.toml +++ b/crates/zeph-skills/Cargo.toml @@ -17,6 +17,7 @@ qdrant-client = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true futures.workspace = true +schemars.workspace = true thiserror.workspace = true tokio = { workspace 
= true, features = ["sync", "rt", "time"] } tracing.workspace = true diff --git a/crates/zeph-skills/src/lib.rs b/crates/zeph-skills/src/lib.rs index 786db062..577e5c54 100644 --- a/crates/zeph-skills/src/lib.rs +++ b/crates/zeph-skills/src/lib.rs @@ -12,4 +12,5 @@ pub mod trust; pub mod watcher; pub use error::SkillError; +pub use matcher::{IntentClassification, ScoredMatch}; pub use trust::{SkillSource, SkillTrust, TrustLevel, compute_skill_hash}; diff --git a/crates/zeph-skills/src/matcher.rs b/crates/zeph-skills/src/matcher.rs index 20bc76b5..d4a6a29c 100644 --- a/crates/zeph-skills/src/matcher.rs +++ b/crates/zeph-skills/src/matcher.rs @@ -1,11 +1,29 @@ +use std::collections::HashMap; use std::time::Duration; +use schemars::JsonSchema; +use serde::Deserialize; + use crate::error::SkillError; use crate::loader::SkillMeta; use futures::stream::{self, StreamExt}; pub use zeph_llm::provider::EmbedFuture; +#[derive(Debug, Clone)] +pub struct ScoredMatch { + pub index: usize, + pub score: f32, +} + +#[derive(Debug, Clone, Deserialize, JsonSchema)] +pub struct IntentClassification { + pub skill_name: String, + pub confidence: f32, + #[serde(default)] + pub params: HashMap<String, String>, +} + #[derive(Debug)] pub struct SkillMatcher { embeddings: Vec<(usize, Vec<f32>)>, @@ -49,7 +67,7 @@ impl SkillMatcher { Some(Self { embeddings }) } - /// Match a user query against stored skill embeddings, returning the top-K indices + /// Match a user query against stored skill embeddings, returning the top-K scored matches /// ranked by cosine similarity. /// /// Returns an empty vec if the query embedding fails. 
@@ -59,7 +77,7 @@ impl SkillMatcher { query: &str, limit: usize, embed_fn: F, - ) -> Vec + ) -> Vec where F: Fn(&str) -> EmbedFuture, { @@ -76,16 +94,23 @@ impl SkillMatcher { } }; - let mut scored: Vec<(usize, f32)> = self + let mut scored: Vec = self .embeddings .iter() - .map(|(idx, emb)| (*idx, cosine_similarity(&query_vec, emb))) + .map(|(idx, emb)| ScoredMatch { + index: *idx, + score: cosine_similarity(&query_vec, emb), + }) .collect(); - scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + scored.sort_unstable_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); scored.truncate(limit); - scored.into_iter().map(|(idx, _)| idx).collect() + scored } } @@ -110,7 +135,7 @@ impl SkillMatcherBackend { query: &str, limit: usize, embed_fn: F, - ) -> Vec + ) -> Vec where F: Fn(&str) -> EmbedFuture, { @@ -249,8 +274,9 @@ mod tests { .await; assert_eq!(matched.len(), 2); - assert_eq!(matched[0], 0); // "a" / "alpha" - assert_eq!(matched[1], 1); // "b" / "beta" + assert_eq!(matched[0].index, 0); // "a" / "alpha" + assert_eq!(matched[1].index, 1); // "b" / "beta" + assert!(matched[0].score >= matched[1].score); } #[tokio::test] @@ -271,7 +297,7 @@ mod tests { .await; assert_eq!(matched.len(), 1); - assert_eq!(matched[0], 0); + assert_eq!(matched[0].index, 0); } #[tokio::test] @@ -394,7 +420,7 @@ mod tests { .await; assert_eq!(matched.len(), 3); - assert_eq!(matched[0], 1); // "close" / "alpha" is closest to "query" + assert_eq!(matched[0].index, 1); // "close" / "alpha" is closest to "query" } #[test] @@ -444,4 +470,114 @@ mod tests { let dbg = format!("{backend:?}"); assert!(dbg.contains("InMemory")); } + + #[test] + fn scored_match_clone_and_debug() { + let sm = ScoredMatch { + index: 0, + score: 0.95, + }; + let cloned = sm.clone(); + assert_eq!(cloned.index, 0); + assert!((cloned.score - 0.95).abs() < f32::EPSILON); + let dbg = format!("{sm:?}"); + assert!(dbg.contains("ScoredMatch")); + 
} + + #[test] + fn intent_classification_deserialize() { + let json = r#"{"skill_name":"git","confidence":0.9,"params":{"branch":"main"}}"#; + let ic: IntentClassification = serde_json::from_str(json).unwrap(); + assert_eq!(ic.skill_name, "git"); + assert!((ic.confidence - 0.9).abs() < f32::EPSILON); + assert_eq!(ic.params.get("branch").unwrap(), "main"); + } + + #[test] + fn intent_classification_deserialize_without_params() { + let json = r#"{"skill_name":"test","confidence":0.5}"#; + let ic: IntentClassification = serde_json::from_str(json).unwrap(); + assert_eq!(ic.skill_name, "test"); + assert!(ic.params.is_empty()); + } + + #[test] + fn intent_classification_json_schema() { + let schema = schemars::schema_for!(IntentClassification); + let json = serde_json::to_string(&schema).unwrap(); + assert!(json.contains("skill_name")); + assert!(json.contains("confidence")); + } + + #[test] + fn intent_classification_rejects_missing_required_fields() { + let json = r#"{"confidence":0.5}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } + + #[test] + fn scored_match_delta_threshold_zero_disables_disambiguation() { + // With threshold = 0.0 the condition `(scores[0] - scores[1]) < threshold` + // evaluates to `delta < 0.0`. For any pair of sorted (descending) scores the + // delta is always >= 0.0, so this threshold effectively disables disambiguation. 
+ let threshold = 0.0_f32; + + let high = ScoredMatch { + index: 0, + score: 0.90, + }; + let low = ScoredMatch { + index: 1, + score: 0.89, + }; + let delta = high.score - low.score; // 0.01 + + assert!( + delta >= 0.0, + "delta between sorted scores is always non-negative" + ); + assert!( + !(delta < threshold), + "with threshold=0.0 disambiguation must NOT be triggered" + ); + } + + #[test] + fn scored_match_delta_at_threshold_boundary() { + let threshold = 0.05_f32; + + // delta clearly above threshold => not ambiguous + let high = ScoredMatch { + index: 0, + score: 0.90, + }; + let low = ScoredMatch { + index: 1, + score: 0.80, + }; + assert!(!((high.score - low.score) < threshold)); + + // delta clearly below threshold => ambiguous + let close = ScoredMatch { + index: 2, + score: 0.89, + }; + assert!((high.score - close.score) < threshold); + } + + #[tokio::test] + async fn match_skills_returns_scores() { + let metas = vec![make_meta("a", "alpha"), make_meta("b", "beta")]; + let refs: Vec<&SkillMeta> = metas.iter().collect(); + + let matcher = SkillMatcher::new(&refs, embed_fn_mapping).await.unwrap(); + let matched = matcher + .match_skills(refs.len(), "query", 2, embed_fn_mapping) + .await; + + assert_eq!(matched.len(), 2); + assert!(matched[0].score > 0.0); + assert!(matched[0].score >= matched[1].score); + } } diff --git a/crates/zeph-skills/src/qdrant_matcher.rs b/crates/zeph-skills/src/qdrant_matcher.rs index 2d36ec28..e80f533e 100644 --- a/crates/zeph-skills/src/qdrant_matcher.rs +++ b/crates/zeph-skills/src/qdrant_matcher.rs @@ -5,7 +5,7 @@ use zeph_memory::QdrantOps; use crate::error::SkillError; use crate::loader::SkillMeta; -use crate::matcher::EmbedFuture; +use crate::matcher::{EmbedFuture, ScoredMatch}; const COLLECTION_NAME: &str = "zeph_skills"; @@ -167,14 +167,14 @@ impl QdrantSkillMatcher { } /// Search for relevant skills using Qdrant native vector search. - /// Returns indices into the provided meta slice. 
+ /// Returns scored matches with indices into the provided meta slice. pub async fn match_skills( &self, meta: &[&SkillMeta], query: &str, limit: usize, embed_fn: F, - ) -> Vec + ) -> Vec where F: Fn(&str) -> EmbedFuture, { @@ -210,7 +210,11 @@ impl QdrantSkillMatcher { Some(Kind::StringValue(s)) => s.as_str(), _ => return None, }; - meta.iter().position(|m| m.name == name_str) + let index = meta.iter().position(|m| m.name == name_str)?; + Some(ScoredMatch { + index, + score: point.score, + }) }) .collect() } diff --git a/docs/src/getting-started/configuration.md b/docs/src/getting-started/configuration.md index ce173347..aa9d0f7a 100644 --- a/docs/src/getting-started/configuration.md +++ b/docs/src/getting-started/configuration.md @@ -76,6 +76,7 @@ max_tokens = 4096 [skills] paths = ["./skills"] max_active_skills = 5 # Top-K skills per query via embedding similarity +disambiguation_threshold = 0.05 # LLM disambiguation when top-2 score delta < threshold (0.0 = disabled) [memory] sqlite_path = "./data/zeph.db" diff --git a/docs/src/guide/skills.md b/docs/src/guide/skills.md index d905aafa..715df38e 100644 --- a/docs/src/guide/skills.md +++ b/docs/src/guide/skills.md @@ -126,6 +126,17 @@ export ZEPH_SKILLS_MAX_ACTIVE=5 Lower values reduce prompt size but may miss relevant skills. Higher values include more context but use more tokens. +### Disambiguation Threshold + +When the top two candidate skills have cosine similarity scores within `disambiguation_threshold` of each other, the agent calls the LLM with a structured prompt to clarify intent. The LLM returns a typed `IntentClassification` (skill name, confidence 0-1, extracted parameters) via `chat_typed`, and the result reorders the candidate list so the best-matching skill is injected first. + +```toml +[skills] +disambiguation_threshold = 0.05 +``` + +Set to `0.0` to disable disambiguation entirely (always use ranking order). Higher values cause disambiguation to trigger more aggressively on ambiguous queries. 
+ ## Progressive Loading Only metadata (~100 tokens per skill) is loaded at startup for embedding and matching. Full body (<5000 tokens) is loaded lazily on first activation and cached via `OnceLock`. Resource files are loaded on demand. diff --git a/src/main.rs b/src/main.rs index c8b513a1..219b7362 100644 --- a/src/main.rs +++ b/src/main.rs @@ -275,6 +275,7 @@ async fn main() -> anyhow::Result<()> { .with_max_tool_iterations(config.agent.max_tool_iterations) .with_model_name(config.llm.model.clone()) .with_embedding_model(embed_model.clone()) + .with_disambiguation_threshold(config.skills.disambiguation_threshold) .with_skill_reload(skill_paths.clone(), reload_rx) .with_memory( memory, diff --git a/tests/integration.rs b/tests/integration.rs index fb30d090..1f092e97 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1874,6 +1874,83 @@ async fn agent_rebuild_with_skill_matcher() { assert_eq!(collected[0], "matched response"); } +// -- disambiguation reorders skill selection when scores are very close -- + +#[tokio::test] +async fn agent_disambiguation_reorders_skill_selection() { + use zeph_skills::matcher::{SkillMatcher, SkillMatcherBackend}; + + // Disambiguation response selects "second-skill"; second response is the normal chat reply. + let disambiguation_json = + r#"{"skill_name":"second-skill","confidence":0.9,"params":{}}"#.to_string(); + let mut provider_inner = + MockProvider::with_responses(vec![disambiguation_json, "ok".to_string()]); + // Identical embeddings for all texts forces a zero-delta between the two matched skills. + provider_inner.supports_embeddings = true; + provider_inner.embedding = vec![1.0, 0.0]; + let provider = AnyProvider::Mock(provider_inner); + + let outputs = Arc::new(Mutex::new(Vec::new())); + let channel = MockChannel::new(vec!["do the second thing"], outputs.clone()); + + let dir = tempfile::tempdir().unwrap(); + // Skill names must use only lowercase letters, digits, and hyphens (no underscores). 
+ let skill_dir1 = dir.path().join("first-skill"); + std::fs::create_dir(&skill_dir1).unwrap(); + std::fs::write( + skill_dir1.join("SKILL.md"), + "---\nname: first-skill\ndescription: first skill\n---\nfirst body", + ) + .unwrap(); + let skill_dir2 = dir.path().join("second-skill"); + std::fs::create_dir(&skill_dir2).unwrap(); + std::fs::write( + skill_dir2.join("SKILL.md"), + "---\nname: second-skill\ndescription: second skill\n---\nsecond body", + ) + .unwrap(); + + let registry = SkillRegistry::load(&[dir.path().to_path_buf()]); + let all_meta = registry.all_meta(); + assert_eq!(all_meta.len(), 2, "both skills must be loaded"); + + // Pre-build InMemory matcher with constant embeddings so both skills score equally. + let embed_fn = |text: &str| -> zeph_skills::matcher::EmbedFuture { + let _ = text; + Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) + }; + let matcher = SkillMatcher::new(&all_meta, embed_fn) + .await + .expect("matcher must be built: embed_fn always succeeds"); + let backend = SkillMatcherBackend::InMemory(matcher); + + let (tx, rx) = tokio::sync::watch::channel(zeph_core::metrics::MetricsSnapshot::default()); + + let mut agent = Agent::new( + provider, + channel, + registry, + Some(backend), + 2, + MockToolExecutor, + ) + // Threshold of 1.0 guarantees disambiguation fires: any score delta < 1.0. + .with_disambiguation_threshold(1.0) + .with_metrics(tx); + + agent.run().await.unwrap(); + + let snapshot = rx.borrow().clone(); + assert!( + !snapshot.active_skills.is_empty(), + "active_skills must be populated after run" + ); + assert_eq!( + snapshot.active_skills[0], "second-skill", + "disambiguation must move second-skill to the front" + ); +} + // -- multiple commands in one session (skills + normal message) -- #[tokio::test]