From 5e251f3f09b11b7467eab1136b31dc4694b7ee0d Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 27 May 2025 15:12:18 -0700 Subject: [PATCH 1/4] limit 5 --- crates/goose/src/agents/router_tool_selector.rs | 4 ++-- crates/goose/src/agents/tool_vectordb.rs | 4 ++-- ui/desktop/src/components/settings_v2/SettingsView.tsx | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index efb0898e828b..f46644716b99 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -60,7 +60,7 @@ impl RouterToolSelector for VectorToolSelector { .and_then(|v| v.as_str()) .ok_or_else(|| ToolError::InvalidParameters("Missing 'query' parameter".to_string()))?; - let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(20) as usize; + let k = params.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize; // Generate embedding for the query let query_embedding = self @@ -74,7 +74,7 @@ impl RouterToolSelector for VectorToolSelector { // Search for similar tools let vector_db = self.vector_db.read().await; let tools = vector_db - .search_tools(query_embedding, limit) + .search_tools(query_embedding, k) .await .map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?; diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index 2a5279ed1097..6c2053858a4f 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -238,7 +238,7 @@ impl ToolVectorDB { pub async fn search_tools( &self, query_vector: Vec, - limit: usize, + k: usize, ) -> Result> { let connection = self.connection.read().await; @@ -251,7 +251,7 @@ impl ToolVectorDB { let results = table .vector_search(query_vector) .context("Failed to create vector search")? - .limit(limit) + .limit(k) .execute() .await .context("Failed to execute vector search")?; diff --git a/ui/desktop/src/components/settings_v2/SettingsView.tsx b/ui/desktop/src/components/settings_v2/SettingsView.tsx index f670cdbaee2b..b6a5bd7d3dec 100644 --- a/ui/desktop/src/components/settings_v2/SettingsView.tsx +++ b/ui/desktop/src/components/settings_v2/SettingsView.tsx @@ -4,6 +4,7 @@ import type { View, ViewOptions } from '../../App'; import ExtensionsSection from './extensions/ExtensionsSection'; import ModelsSection from './models/ModelsSection'; import { ModeSection } from './mode/ModeSection'; +import { ToolSelectionStrategySection } from './tool_selection_strategy/ToolSelectionStrategySection'; import SessionSharingSection from './sessions/SessionSharingSection'; import { ResponseStylesSection } from './response_styles/ResponseStylesSection'; import { ExtensionConfig } from '../../api'; @@ -50,6 +51,8 @@ export default function SettingsView({ {/* Response Styles */} + {/* Tool Selection Strategy */} + From af4ada39589be0a4d35e7564926c98e39ce94a34 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 27 May 2025 18:32:02 -0700 Subject: [PATCH 2/4] table creation by timestamp --- Cargo.lock | 1 + crates/goose/Cargo.toml | 1 + crates/goose/examples/agent.rs | 1 + crates/goose/src/agents/agent.rs | 43 ++++++++++--- .../goose/src/agents/router_tool_selector.rs | 28 ++++++-- crates/goose/src/agents/tool_vectordb.rs | 64 +++++++++++++------ crates/goose/src/providers/databricks.rs | 2 +- 7 files changed, 105 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d81ee21bf521..62a0f509ae39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3441,6 +3441,7 @@ dependencies = [ "chrono", "criterion", "ctor", + "dirs 5.0.1", "dotenv", "downcast-rs", "etcetera", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index dbdb23ff69d7..f1aa0d262aa4 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -61,6 +61,7 @@ once_cell = "1.20.2" etcetera = "0.8.0" rand = "0.8.5" utoipa = "4.1" +dirs = "5.0" # For Bedrock provider aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index bc3badac5372..e3f5c7f77023 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -13,6 +13,7 @@ async fn main() { let _ = dotenv(); let provider = Arc::new(DatabricksProvider::default()); + // Setup an agent with the developer extension let agent = Agent::new(); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 537340437ab2..ed724bba62a1 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -29,6 +29,7 @@ use crate::agents::router_tool_selector::{ create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, }; use crate::agents::router_tools::ROUTER_VECTOR_SEARCH_TOOL_NAME; +use crate::agents::tool_vectordb::generate_table_id; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; use mcp_core::{ @@ -192,7 +193,9 @@ impl Agent { "Frontend tool execution required".to_string(), )) } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME { + eprintln!("[DEBUG] Received tool call: {:?}", tool_call); let router_tool_selector = self.router_tool_selector.lock().await; + eprintln!("[DEBUG] Router tool selector: "); if let Some(selector) = router_tool_selector.as_ref() { selector.select_tools(tool_call.arguments.clone()).await } else { @@ -669,25 +672,49 @@ impl Agent { let router_tool_selection_strategy = config .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") .unwrap_or_else(|_| "default".to_string()); + + eprintln!("[DEBUG] Router tool selection strategy from config: {}", router_tool_selection_strategy); let strategy = match router_tool_selection_strategy.to_lowercase().as_str() { "vector" => Some(RouterToolSelectionStrategy::Vector), _ => None, }; - if let Some(strategy) = strategy { - let selector = create_tool_selector(Some(strategy), provider) - .await - .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; + eprintln!("[DEBUG] Parsed strategy: {:?}", strategy); - // Clear tools from the vector database - selector - .clear_tools() + if let Some(strategy) = strategy { + eprintln!("[DEBUG] Creating tool selector with vector strategy..."); + let table_name = generate_table_id(); + eprintln!("[DEBUG] Table name: {}", table_name); + let selector = create_tool_selector( + Some(strategy), + provider, + table_name + ) .await - .map_err(|e| anyhow!("Failed to clear tools: {}", e))?; + .map_err(|e| { + eprintln!("[DEBUG] Failed to create tool selector: {}", e); + anyhow!("Failed to create tool selector: {}", e) + })?; + + eprintln!("[DEBUG] Clearing existing tools from vector database..."); + // // Clear tools from the vector database + // selector + // .clear_tools() + // .await + // .map_err(|e| { + // eprintln!("[DEBUG] Failed to clear tools: {}", e); + // anyhow!("Failed to clear tools: {}", e) + // })?; + + eprintln!("[DEBUG] Setting router tool selector..."); *self.router_tool_selector.lock().await = Some(selector); + eprintln!("[DEBUG] Indexing platform tools..."); self.index_platform_tools().await?; + eprintln!("[DEBUG] Router tool selector initialization complete"); + } else { + eprintln!("[DEBUG] No vector strategy selected, skipping router tool selector initialization"); } Ok(()) diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index f46644716b99..5a182928a7b9 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -11,6 +11,7 @@ use crate::agents::embeddings::{create_embedding_provider, EmbeddingProviderTrai use crate::agents::tool_vectordb::ToolVectorDB; use crate::providers::base::Provider; +#[derive(Debug, Clone)] pub enum RouterToolSelectionStrategy { Vector, } @@ -37,8 +38,8 @@ pub struct VectorToolSelector { } impl VectorToolSelector { - pub async fn new(provider: Arc) -> Result { - let vector_db = ToolVectorDB::new(Some("tools".to_string())) + pub async fn new(provider: Arc, table_name: String) -> Result { + let vector_db = ToolVectorDB::new(Some(table_name)) .await .map_err(|e| ToolError::ExecutionError(format!("Failed to create vector DB: {}", e)))?; @@ -55,28 +56,41 @@ impl VectorToolSelector { #[async_trait] impl RouterToolSelector for VectorToolSelector { async fn select_tools(&self, params: Value) -> Result, ToolError> { + eprintln!("[DEBUG] Received params: {:?}", params); + let query = params .get("query") .and_then(|v| v.as_str()) .ok_or_else(|| ToolError::InvalidParameters("Missing 'query' parameter".to_string()))?; + eprintln!("[DEBUG] Extracted query: {}", query); + let k = params.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize; + eprintln!("[DEBUG] Using k value: {}", k); // Generate embedding for the query + eprintln!("[DEBUG] Generating embedding for query..."); let query_embedding = self .embedding_provider .embed_single(query.to_string()) .await .map_err(|e| { + eprintln!("[DEBUG] Embedding generation failed: {}", e); ToolError::ExecutionError(format!("Failed to generate query embedding: {}", e)) })?; + eprintln!("[DEBUG] Successfully generated embedding"); // Search for similar tools + eprintln!("[DEBUG] Starting vector search..."); let vector_db = self.vector_db.read().await; let tools = vector_db .search_tools(query_embedding, k) .await - .map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?; + .map_err(|e| { + eprintln!("[DEBUG] Vector search failed: {}", e); + ToolError::ExecutionError(format!("Failed to search tools: {}", e)) + })?; + eprintln!("[DEBUG] Vector search completed, found {} tools", tools.len()); // Convert tool records to Content let selected_tools: Vec = tools @@ -93,6 +107,7 @@ impl RouterToolSelector for VectorToolSelector { }) .collect(); + eprintln!("[DEBUG] Successfully converted {} tools to Content", selected_tools.len()); Ok(selected_tools) } @@ -132,7 +147,7 @@ impl RouterToolSelector for VectorToolSelector { } async fn clear_tools(&self) -> Result<(), ToolError> { - let vector_db = self.vector_db.read().await; + let vector_db = self.vector_db.write().await; vector_db .clear_tools() .await @@ -167,14 +182,15 @@ impl RouterToolSelector for VectorToolSelector { pub async fn create_tool_selector( strategy: Option, provider: Arc, + table_name: String, ) -> Result, ToolError> { match strategy { Some(RouterToolSelectionStrategy::Vector) => { - let selector = VectorToolSelector::new(provider).await?; + let selector = VectorToolSelector::new(provider, table_name).await?; Ok(Box::new(selector)) } None => { - let selector = VectorToolSelector::new(provider).await?; + let selector = VectorToolSelector::new(provider, table_name).await?; Ok(Box::new(selector)) } } diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index 6c2053858a4f..20f18604643b 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -1,7 +1,7 @@ use anyhow::{Context, Result}; use arrow::array::{FixedSizeListBuilder, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; -use etcetera::BaseStrategy; +use chrono::Local; use futures::TryStreamExt; use lancedb::connect; use lancedb::connection::Connection; @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; +use dirs; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolRecord { @@ -45,6 +46,8 @@ impl ToolVectorDB { table_name: table_name.unwrap_or_else(|| "tools".to_string()), }; + eprintln!("[DEBUG] Table name: {}", tool_db.table_name); + // Initialize the table if it doesn't exist tool_db.init_table().await?; @@ -52,11 +55,18 @@ impl ToolVectorDB { } fn get_db_path() -> Result { - let data_dir = etcetera::choose_base_strategy() - .context("Failed to determine base strategy")? - .data_dir(); + let home_dir = dirs::home_dir() + .context("Failed to get home directory")? + .join(".goose") + .join("tool_db"); + + // Ensure the directory exists + if let Some(parent) = home_dir.parent() { + std::fs::create_dir_all(parent) + .context("Failed to create database directory")?; + } - Ok(data_dir.join("goose").join("tool_db")) + Ok(home_dir) } async fn init_table(&self) -> Result<()> { @@ -120,33 +130,43 @@ impl ToolVectorDB { .create_table(&self.table_name, Box::new(reader)) .execute() .await - .context("Failed to create tools table")?; + .map_err(|e| anyhow::anyhow!("Failed to create tools table '{}': {}", self.table_name, e))?; } Ok(()) } pub async fn clear_tools(&self) -> Result<()> { + eprintln!("[DEBUG] Starting clear_tools operation..."); let connection = self.connection.write().await; - - // Drop the table if it exists - let table_names = connection - .table_names() - .execute() - .await - .context("Failed to list tables")?; - - if table_names.contains(&self.table_name) { - connection - .drop_table(&self.table_name) - .await - .context("Failed to drop tools table")?; + eprintln!("[DEBUG] Acquired write lock on connection"); + + // Try to open the table first + eprintln!("[DEBUG] Attempting to open table {}", self.table_name); + match connection.open_table(&self.table_name).execute().await { + Ok(table) => { + eprintln!("[DEBUG] Successfully opened table, attempting to delete all records"); + // Delete all records instead of dropping the table + table + .delete("1=1") // This will match all records + .await + .context("Failed to delete all records")?; + eprintln!("[DEBUG] Successfully deleted all records"); + } + Err(e) => { + eprintln!("[DEBUG] Error opening table: {:?}", e); + // If table doesn't exist, that's fine - we'll create it + eprintln!("[DEBUG] Table may not exist, will create if needed"); + } } drop(connection); + eprintln!("[DEBUG] Released write lock on connection"); - // Reinitialize the table + // Ensure table exists with correct schema + eprintln!("[DEBUG] Ensuring table exists with correct schema"); self.init_table().await?; + eprintln!("[DEBUG] Successfully initialized table"); Ok(()) } @@ -324,6 +344,10 @@ impl ToolVectorDB { } } +pub fn generate_table_id() -> String { + Local::now().format("%Y%m%d_%H%M%S").to_string() +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index ca636d094ed0..ace5f5cc4afe 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -1,5 +1,5 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; -use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; +use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::formats::databricks::{create_request, get_usage, response_to_message}; use super::oauth; From 92ab77bdfeb160fcba30b94edf89490180ea651a Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 27 May 2025 18:36:22 -0700 Subject: [PATCH 3/4] checkpoint --- crates/goose/examples/agent.rs | 1 - crates/goose/src/agents/agent.rs | 17 +++++++++-------- crates/goose/src/agents/embeddings.rs | 4 ++-- crates/goose/src/agents/router_tool_selector.rs | 10 ++++++++-- crates/goose/src/agents/tool_vectordb.rs | 15 ++++++--------- 5 files changed, 25 insertions(+), 22 deletions(-) diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index e3f5c7f77023..bc3badac5372 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -13,7 +13,6 @@ async fn main() { let _ = dotenv(); let provider = Arc::new(DatabricksProvider::default()); - // Setup an agent with the developer extension let agent = Agent::new(); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index ed724bba62a1..ee7f0337b508 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -672,8 +672,11 @@ impl Agent { let router_tool_selection_strategy = config .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") .unwrap_or_else(|_| "default".to_string()); - - eprintln!("[DEBUG] Router tool selection strategy from config: {}", router_tool_selection_strategy); + + eprintln!( + "[DEBUG] Router tool selection strategy from config: {}", + router_tool_selection_strategy + ); let strategy = match router_tool_selection_strategy.to_lowercase().as_str() { "vector" => Some(RouterToolSelectionStrategy::Vector), @@ -686,11 +689,7 @@ impl Agent { eprintln!("[DEBUG] Creating tool selector with vector strategy..."); let table_name = generate_table_id(); eprintln!("[DEBUG] Table name: {}", table_name); - let selector = create_tool_selector( - Some(strategy), - provider, - table_name - ) + let selector = create_tool_selector(Some(strategy), provider, table_name) .await .map_err(|e| { eprintln!("[DEBUG] Failed to create tool selector: {}", e); @@ -714,7 +713,9 @@ impl Agent { self.index_platform_tools().await?; eprintln!("[DEBUG] Router tool selector initialization complete"); } else { - eprintln!("[DEBUG] No vector strategy selected, skipping router tool selector initialization"); + eprintln!( + "[DEBUG] No vector strategy selected, skipping router tool selector initialization" + ); } Ok(()) diff --git a/crates/goose/src/agents/embeddings.rs b/crates/goose/src/agents/embeddings.rs index 3e315c5ed08a..a2da242da376 100644 --- a/crates/goose/src/agents/embeddings.rs +++ b/crates/goose/src/agents/embeddings.rs @@ -349,14 +349,14 @@ mod tests { env::set_var("OPENAI_API_KEY", "test_key"); let provider = EmbeddingProvider::new(Arc::new(MockProvider)).unwrap(); assert_eq!(provider.token, "test_key"); - assert_eq!(provider.model, "mock-model"); + assert_eq!(provider.model.model_name, "mock-model"); assert_eq!(provider.base_url, "https://api.openai.com/v1"); // Test with custom configuration env::set_var("EMBEDDING_MODEL", "custom-model"); env::set_var("EMBEDDING_BASE_URL", "https://custom.api.com"); let provider = EmbeddingProvider::new(Arc::new(MockProvider)).unwrap(); - assert_eq!(provider.model, "custom-model"); + assert_eq!(provider.model.model_name, "custom-model"); assert_eq!(provider.base_url, "https://custom.api.com"); // Cleanup diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index 5a182928a7b9..b8c0e64438d5 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -90,7 +90,10 @@ impl RouterToolSelector for VectorToolSelector { eprintln!("[DEBUG] Vector search failed: {}", e); ToolError::ExecutionError(format!("Failed to search tools: {}", e)) })?; - eprintln!("[DEBUG] Vector search completed, found {} tools", tools.len()); + eprintln!( + "[DEBUG] Vector search completed, found {} tools", + tools.len() + ); // Convert tool records to Content let selected_tools: Vec = tools @@ -107,7 +110,10 @@ impl RouterToolSelector for VectorToolSelector { }) .collect(); - eprintln!("[DEBUG] Successfully converted {} tools to Content", selected_tools.len()); + eprintln!( + "[DEBUG] Successfully converted {} tools to Content", + selected_tools.len() + ); Ok(selected_tools) } diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index 20f18604643b..c4a567b37cdb 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -2,6 +2,7 @@ use anyhow::{Context, Result}; use arrow::array::{FixedSizeListBuilder, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use chrono::Local; +use dirs; use futures::TryStreamExt; use lancedb::connect; use lancedb::connection::Connection; @@ -10,7 +11,6 @@ use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; -use dirs; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolRecord { @@ -62,8 +62,7 @@ impl ToolVectorDB { // Ensure the directory exists if let Some(parent) = home_dir.parent() { - std::fs::create_dir_all(parent) - .context("Failed to create database directory")?; + std::fs::create_dir_all(parent).context("Failed to create database directory")?; } Ok(home_dir) @@ -130,7 +129,9 @@ impl ToolVectorDB { .create_table(&self.table_name, Box::new(reader)) .execute() .await - .map_err(|e| anyhow::anyhow!("Failed to create tools table '{}': {}", self.table_name, e))?; + .map_err(|e| { + anyhow::anyhow!("Failed to create tools table '{}': {}", self.table_name, e) + })?; } Ok(()) @@ -255,11 +256,7 @@ impl ToolVectorDB { Ok(()) } - pub async fn search_tools( - &self, - query_vector: Vec, - k: usize, - ) -> Result> { + pub async fn search_tools(&self, query_vector: Vec, k: usize) -> Result> { let connection = self.connection.read().await; let table = connection From 1afbc788b9e2790ced8ab62daa2d3e7743d57523 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 27 May 2025 18:42:52 -0700 Subject: [PATCH 4/4] remove logging; add ToolSelectionStrategySection.tsx component --- Cargo.lock | 1 - crates/goose/Cargo.toml | 1 - crates/goose/src/agents/tool_vectordb.rs | 31 ++---- .../ToolSelectionStrategySection.tsx | 100 ++++++++++++++++++ 4 files changed, 106 insertions(+), 27 deletions(-) create mode 100644 ui/desktop/src/components/settings_v2/tool_selection_strategy/ToolSelectionStrategySection.tsx diff --git a/Cargo.lock b/Cargo.lock index 62a0f509ae39..d81ee21bf521 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3441,7 +3441,6 @@ dependencies = [ "chrono", "criterion", "ctor", - "dirs 5.0.1", "dotenv", "downcast-rs", "etcetera", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index f1aa0d262aa4..dbdb23ff69d7 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -61,7 +61,6 @@ once_cell = "1.20.2" etcetera = "0.8.0" rand = "0.8.5" utoipa = "4.1" -dirs = "5.0" # For Bedrock provider aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index c4a567b37cdb..f0e2397e39cd 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -2,7 +2,7 @@ use anyhow::{Context, Result}; use arrow::array::{FixedSizeListBuilder, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use chrono::Local; -use dirs; +use etcetera::base_strategy::{BaseStrategy, Xdg}; use futures::TryStreamExt; use lancedb::connect; use lancedb::connection::Connection; @@ -46,8 +46,6 @@ impl ToolVectorDB { table_name: table_name.unwrap_or_else(|| "tools".to_string()), }; - eprintln!("[DEBUG] Table name: {}", tool_db.table_name); - // Initialize the table if it doesn't exist tool_db.init_table().await?; @@ -55,17 +53,11 @@ impl ToolVectorDB { } fn get_db_path() -> Result { - let home_dir = dirs::home_dir() - .context("Failed to get home directory")? - .join(".goose") - .join("tool_db"); - - // Ensure the directory exists - if let Some(parent) = home_dir.parent() { - std::fs::create_dir_all(parent).context("Failed to create database directory")?; - } + let data_dir = Xdg::new() + .context("Failed to determine base strategy")? + .data_dir(); - Ok(home_dir) + Ok(data_dir.join("goose").join("tool_db")) } async fn init_table(&self) -> Result<()> { @@ -138,36 +130,26 @@ impl ToolVectorDB { } pub async fn clear_tools(&self) -> Result<()> { - eprintln!("[DEBUG] Starting clear_tools operation..."); let connection = self.connection.write().await; - eprintln!("[DEBUG] Acquired write lock on connection"); // Try to open the table first - eprintln!("[DEBUG] Attempting to open table {}", self.table_name); match connection.open_table(&self.table_name).execute().await { Ok(table) => { - eprintln!("[DEBUG] Successfully opened table, attempting to delete all records"); // Delete all records instead of dropping the table table .delete("1=1") // This will match all records .await .context("Failed to delete all records")?; - eprintln!("[DEBUG] Successfully deleted all records"); } - Err(e) => { - eprintln!("[DEBUG] Error opening table: {:?}", e); + Err(_) => { // If table doesn't exist, that's fine - we'll create it - eprintln!("[DEBUG] Table may not exist, will create if needed"); } } drop(connection); - eprintln!("[DEBUG] Released write lock on connection"); // Ensure table exists with correct schema - eprintln!("[DEBUG] Ensuring table exists with correct schema"); self.init_table().await?; - eprintln!("[DEBUG] Successfully initialized table"); Ok(()) } @@ -309,7 +291,6 @@ impl ToolVectorDB { for i in 0..batch.num_rows() { let tool_name = tool_names.value(i).to_string(); let distance = distances.value(i); - eprintln!("Tool: {}, Distance Score: {}", tool_name, distance); tools.push(ToolRecord { tool_name, diff --git a/ui/desktop/src/components/settings_v2/tool_selection_strategy/ToolSelectionStrategySection.tsx b/ui/desktop/src/components/settings_v2/tool_selection_strategy/ToolSelectionStrategySection.tsx new file mode 100644 index 000000000000..7a744ecde3d3 --- /dev/null +++ b/ui/desktop/src/components/settings_v2/tool_selection_strategy/ToolSelectionStrategySection.tsx @@ -0,0 +1,100 @@ +import { useEffect, useState, useCallback } from 'react'; +import { View, ViewOptions } from '../../../App'; +import { useConfig } from '../../ConfigContext'; + +interface ToolSelectionStrategySectionProps { + setView: (view: View, viewOptions?: ViewOptions) => void; +} + +export const all_tool_selection_strategies = [ + { + key: 'default', + label: 'Default', + description: 'Loads all tools from enabled extension', + }, + { + key: 'vector', + label: 'Vector', + description: + 'Filter tools based on vector-based similarity. Recommended when many extensions are enabled.', + }, +]; + +export const ToolSelectionStrategySection = ({ + setView: _setView, +}: ToolSelectionStrategySectionProps) => { + const [currentStrategy, setCurrentStrategy] = useState('default'); + const { read, upsert } = useConfig(); + + const handleStrategyChange = async (newStrategy: string) => { + try { + await upsert('GOOSE_ROUTER_TOOL_SELECTION_STRATEGY', newStrategy, false); + setCurrentStrategy(newStrategy); + } catch (error) { + console.error('Error updating tool selection strategy:', error); + throw new Error(`Failed to store new tool selection strategy: ${newStrategy}`); + } + }; + + const fetchCurrentStrategy = useCallback(async () => { + try { + const strategy = (await read('GOOSE_ROUTER_TOOL_SELECTION_STRATEGY', false)) as string; + if (strategy) { + setCurrentStrategy(strategy); + } + } catch (error) { + console.error('Error fetching current tool selection strategy:', error); + } + }, [read]); + + useEffect(() => { + fetchCurrentStrategy(); + }, [fetchCurrentStrategy]); + + return ( +
+
+

Tool Selection Strategy

+
+
+

+ Configure how Goose selects tools for your requests +

+
+ {all_tool_selection_strategies.map((strategy) => ( +
+
handleStrategyChange(strategy.key)} + > +
+
+

{strategy.label}

+

{strategy.description}

+
+
+ +
+ handleStrategyChange(strategy.key)} + className="peer sr-only" + /> +
+
+
+
+ ))} +
+
+
+ ); +};