Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,032 changes: 85 additions & 1,947 deletions Cargo.lock

Large diffs are not rendered by default.

41 changes: 14 additions & 27 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
.item(
"goose_router_strategy",
"Router Tool Selection Strategy",
"Configure the strategy for selecting tools to use",
"Experimental: configure a strategy for auto selecting tools to use",
)
.item(
"tool_permission",
Expand Down Expand Up @@ -1300,40 +1300,27 @@ pub fn configure_goose_mode_dialog() -> Result<(), Box<dyn Error>> {
pub fn configure_goose_router_strategy_dialog() -> Result<(), Box<dyn Error>> {
let config = Config::global();

// Check if GOOSE_ROUTER_STRATEGY is set as an environment variable
if std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY").is_ok() {
let _ = cliclack::log::info("Notice: GOOSE_ROUTER_TOOL_SELECTION_STRATEGY environment variable is set. Configuration will override this.");
}

let strategy = cliclack::select("Which router strategy would you like to use?")
let enable_router = cliclack::select("Would you like to enable smart tool routing?")
.item(
"vector",
"Vector Strategy",
"Use vector-based similarity to select tools",
"true",
"Enable Router",
"Use LLM-based intelligence to select tools",
)
.item(
"default",
"Default Strategy",
"false",
"Disable Router",
"Use the default tool selection strategy",
)
.interact()?;

match strategy {
"vector" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("vector".to_string()),
)?;
cliclack::outro(
"Set to Vector Strategy - using vector-based similarity for tool selection",
)?;
match enable_router {
"true" => {
config.set_param("GOOSE_ENABLE_ROUTER", Value::String("true".to_string()))?;
cliclack::outro("Router enabled - using LLM-based intelligence for tool selection")?;
}
"default" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("default".to_string()),
)?;
cliclack::outro("Set to Default Strategy - using default tool selection")?;
"false" => {
config.set_param("GOOSE_ENABLE_ROUTER", Value::String("false".to_string()))?;
cliclack::outro("Router disabled - using default tool selection")?;
}
_ => unreachable!(),
};
Expand Down
2 changes: 0 additions & 2 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ ahash = "0.8"
tokio-util = "0.7.15"
unicode-normalization = "0.1"

# Vector database for tool selection
lancedb = "0.13"
arrow = "52.2"
oauth2 = "5.0.0"

Expand Down
82 changes: 37 additions & 45 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ use crate::agents::recipe_tools::dynamic_task_tools::{
create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX,
};
use crate::agents::retry::{RetryManager, RetryResult};
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME};
use crate::agents::router_tools::ROUTER_LLM_SEARCH_TOOL_NAME;
use crate::agents::sub_recipe_manager::SubRecipeManager;
use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{
self, SUBAGENT_EXECUTE_TASK_TOOL_NAME,
Expand Down Expand Up @@ -530,9 +529,7 @@ impl Agent {
"Updated ({} chars)",
char_count
))]))
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME
|| tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME
{
} else if tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME {
match self
.tool_route_manager
.dispatch_route_search_tool(tool_call.arguments)
Expand Down Expand Up @@ -575,8 +572,8 @@ impl Agent {
extension_name: String,
request_id: String,
) -> (String, Result<Vec<Content>, ErrorData>) {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if let Some(selector) = selector {
let selector_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.read().await;
Expand All @@ -593,7 +590,7 @@ impl Agent {
request_id,
Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
format!("Failed to update vector index: {}", e),
format!("Failed to update LLM index: {}", e),
None,
)),
);
Expand Down Expand Up @@ -653,31 +650,29 @@ impl Agent {
.map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None));

drop(extension_manager);
// Update vector index if operation was successful and vector routing is enabled
if result.is_ok() {
// Update LLM index if operation was successful and LLM routing is functional
if result.is_ok() && self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if let Some(selector) = selector {
let vector_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.read().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
vector_action,
)
.await
{
return (
request_id,
Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
format!("Failed to update vector index: {}", e),
None,
)),
);
}
if let Some(selector) = selector {
let llm_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.read().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
llm_action,
)
.await
{
return (
request_id,
Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
format!("Failed to update LLM index: {}", e),
None,
)),
);
}
}
}
Expand Down Expand Up @@ -718,9 +713,9 @@ impl Agent {
}
}

// If vector tool selection is enabled, index the tools
let selector = self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
// If LLM tool selection is functional, index the tools
if self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if let Some(selector) = selector {
let extension_manager = self.extension_manager.read().await;
let selector = Arc::new(selector);
Expand Down Expand Up @@ -787,12 +782,9 @@ impl Agent {
prefixed_tools
}

pub async fn list_tools_for_router(
&self,
strategy: Option<RouterToolSelectionStrategy>,
) -> Vec<Tool> {
pub async fn list_tools_for_router(&self) -> Vec<Tool> {
self.tool_route_manager
.list_tools_for_router(strategy, &self.extension_manager)
.list_tools_for_router(&self.extension_manager)
.await
}

Expand All @@ -801,9 +793,9 @@ impl Agent {
extension_manager.remove_extension(name).await?;
drop(extension_manager);

// If vector tool selection is enabled, remove tools from the index
let selector = self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
// If LLM tool selection is functional, remove tools from the index
if self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if let Some(selector) = selector {
let extension_manager = self.extension_manager.read().await;
ToolRouterIndexManager::update_extension_tools(
Expand Down Expand Up @@ -1350,7 +1342,7 @@ impl Agent {
self.frontend_instructions.lock().await.clone(),
extension_manager.suggest_disable_extensions_prompt().await,
Some(model_name),
None,
false,
);

let recipe_prompt = prompt_manager.get_recipe_prompt().await;
Expand Down Expand Up @@ -1509,7 +1501,7 @@ mod tests {

let prompt_manager = agent.prompt_manager.lock().await;
let system_prompt =
prompt_manager.build_system_prompt(vec![], None, Value::Null, None, None);
prompt_manager.build_system_prompt(vec![], None, Value::Null, None, false);

let final_output_tool_ref = agent.final_output_tool.lock().await;
let final_output_tool_system_prompt =
Expand Down
1 change: 0 additions & 1 deletion crates/goose/src/agents/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub mod todo_tools;
mod tool_execution;
mod tool_route_manager;
mod tool_router_index_manager;
pub(crate) mod tool_vectordb;
pub mod types;

pub use agent::{Agent, AgentEvent};
Expand Down
34 changes: 12 additions & 22 deletions crates/goose/src/agents/prompt_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use serde_json::Value;
use std::collections::HashMap;

use crate::agents::extension::ExtensionInfo;
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::agents::router_tools::{llm_search_tool_prompt, vector_search_tool_prompt};
use crate::agents::router_tools::llm_search_tool_prompt;
use crate::providers::base::get_current_model;
use crate::{config::Config, prompt_template, utils::sanitize_unicode_tags};

Expand Down Expand Up @@ -69,7 +68,7 @@ impl PromptManager {
frontend_instructions: Option<String>,
suggest_disable_extensions_prompt: Value,
model_name: Option<&str>,
tool_selection_strategy: Option<RouterToolSelectionStrategy>,
router_enabled: bool,
) -> String {
let mut context: HashMap<&str, Value> = HashMap::new();
let mut extensions_info = extensions_info.clone();
Expand All @@ -96,20 +95,11 @@ impl PromptManager {
serde_json::to_value(sanitized_extensions_info).unwrap(),
);

match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Vector) => {
context.insert(
"tool_selection_strategy",
Value::String(vector_search_tool_prompt()),
);
}
Some(RouterToolSelectionStrategy::Llm) => {
context.insert(
"tool_selection_strategy",
Value::String(llm_search_tool_prompt()),
);
}
None => {}
if router_enabled {
context.insert(
"tool_selection_strategy",
Value::String(llm_search_tool_prompt()),
);
}

context.insert(
Expand Down Expand Up @@ -246,7 +236,7 @@ mod tests {
manager.set_system_prompt_override(malicious_override.to_string());

let result =
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, false);

assert!(!result.contains('\u{E0041}'));
assert!(!result.contains('\u{E0042}'));
Expand All @@ -262,7 +252,7 @@ mod tests {
manager.add_system_prompt_extra(malicious_extra.to_string());

let result =
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, false);

assert!(!result.contains('\u{E0041}'));
assert!(!result.contains('\u{E0042}'));
Expand All @@ -279,7 +269,7 @@ mod tests {
manager.add_system_prompt_extra("Third\u{E0043}instruction".to_string());

let result =
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, false);

assert!(!result.contains('\u{E0041}'));
assert!(!result.contains('\u{E0042}'));
Expand All @@ -296,7 +286,7 @@ mod tests {
manager.add_system_prompt_extra(legitimate_unicode.to_string());

let result =
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, false);

assert!(result.contains("世界"));
assert!(result.contains("🌍"));
Expand All @@ -318,7 +308,7 @@ mod tests {
None,
Value::String("".to_string()),
None,
None,
false,
);

assert!(!result.contains('\u{E0041}'));
Expand Down
28 changes: 10 additions & 18 deletions crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use async_stream::try_stream;
use futures::stream::StreamExt;

use super::super::agents::Agent;
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::conversation::message::{Message, MessageContent, ToolRequest};
use crate::conversation::Conversation;
use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
Expand Down Expand Up @@ -35,24 +34,17 @@ async fn toolshim_postprocess(
impl Agent {
/// Prepares tools and system prompt for a provider request
pub async fn prepare_tools_and_prompt(&self) -> anyhow::Result<(Vec<Tool>, Vec<Tool>, String)> {
// Get tool selection strategy from config
let tool_selection_strategy = self
.tool_route_manager
.get_router_tool_selection_strategy()
.await;
// Get router enabled status
let router_enabled = self.tool_route_manager.is_router_enabled().await;

// Get tools from extension manager
let mut tools = match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Vector) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Vector))
.await
}
Some(RouterToolSelectionStrategy::Llm) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Llm))
.await
}
_ => self.list_tools(None).await,
};
let mut tools = self.list_tools_for_router().await;

// If router is disabled and no tools were returned, fall back to regular tools
if !router_enabled && tools.is_empty() {
tools = self.list_tools(None).await;
}

// Add frontend tools
let frontend_tools = self.frontend_tools.lock().await;
for frontend_tool in frontend_tools.values() {
Expand All @@ -74,7 +66,7 @@ impl Agent {
self.frontend_instructions.lock().await.clone(),
extension_manager.suggest_disable_extensions_prompt().await,
Some(model_name),
tool_selection_strategy,
router_enabled,
);

// Handle toolshim if enabled
Expand Down
Loading
Loading