Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -691,15 +691,15 @@ pub async fn configure_provider_dialog() -> anyhow::Result<bool> {
};
spin.stop(style("Model fetch complete").green());

// Select a model: on fetch error show styled error and abort; if Some(models), show list; if None, free-text input
// Select a model: on fetch error show styled error and abort; if models available, show list; otherwise free-text input
let model: String = match models_res {
Err(e) => {
// Provider hook error
cliclack::outro(style(e.to_string()).on_red().white())?;
return Ok(false);
}
Ok(Some(models)) => select_model_from_list(&models, provider_meta)?,
Ok(None) => {
Ok(models) if !models.is_empty() => select_model_from_list(&models, provider_meta)?,
Ok(_) => {
let default_model =
std::env::var("GOOSE_MODEL").unwrap_or(provider_meta.default_model.clone());
cliclack::input("Enter a model from that provider:")
Expand Down
16 changes: 1 addition & 15 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,6 @@ pub async fn providers() -> Result<Json<Vec<ProviderDetails>>, ErrorResponse> {
pub async fn get_provider_models(
Path(name): Path<String>,
) -> Result<Json<Vec<String>>, ErrorResponse> {
let loaded_provider = goose::config::declarative_providers::load_provider(name.as_str()).ok();
// TODO(Douwe): support a get models url for custom providers
if let Some(loaded_provider) = loaded_provider {
return Ok(Json(
loaded_provider
.config
.models
.into_iter()
.map(|m| m.name)
.collect::<Vec<_>>(),
));
}

let all = get_providers().await.into_iter().collect::<Vec<_>>();
let Some((metadata, provider_type)) = all.into_iter().find(|(m, _)| m.name == name) else {
return Err(ErrorResponse::bad_request(format!(
Expand All @@ -405,8 +392,7 @@ pub async fn get_provider_models(
let models_result = provider.fetch_recommended_models().await;

match models_result {
Ok(Some(models)) => Ok(Json(models)),
Ok(None) => Ok(Json(Vec::new())),
Ok(models) => Ok(Json(models)),
Err(provider_error) => Err(provider_error.into()),
}
}
Expand Down
5 changes: 4 additions & 1 deletion crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ async fn enhance_model_error(error: ProviderError, provider: &Arc<dyn Provider>)
return error;
}

let Ok(Some(models)) = provider.fetch_recommended_models().await else {
let Ok(models) = provider.fetch_recommended_models().await else {
return error;
};
if models.is_empty() {
return error;
}

ProviderError::RequestFailed(format!(
"{}. Available models for this provider: {}",
Expand Down
13 changes: 7 additions & 6 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl Provider for AnthropicProvider {
Ok((message, provider_usage))
}

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let response = self.api_client.request(None, "v1/models").api_get().await?;

if response.status != StatusCode::OK {
Expand All @@ -267,17 +267,18 @@ impl Provider for AnthropicProvider {
}

let json = response.payload.unwrap_or_default();
let arr = match json.get("data").and_then(|v| v.as_array()) {
Some(arr) => arr,
None => return Ok(None),
};
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::RequestFailed(
"Missing 'data' array in Anthropic models response".to_string(),
)
})?;

let mut models: Vec<String> = arr
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(Some(models))
Ok(models)
}

async fn stream(
Expand Down
4 changes: 3 additions & 1 deletion crates/goose/src/providers/auto_detect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec<
})
.await
{
Ok(Some(models)) => Some((provider_name.to_string(), models)),
Ok(models) if !models.is_empty() => {
Some((provider_name.to_string(), models))
}
_ => None,
}
}
Expand Down
15 changes: 6 additions & 9 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,16 +447,13 @@ pub trait Provider: Send + Sync {
RetryConfig::default()
}

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
Ok(None)
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
Ok(vec![])
}

/// Fetch models filtered by canonical registry and usability
async fn fetch_recommended_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
let all_models = match self.fetch_supported_models().await? {
Some(models) => models,
None => return Ok(None),
};
async fn fetch_recommended_models(&self) -> Result<Vec<String>, ProviderError> {
let all_models = self.fetch_supported_models().await?;

let registry = CanonicalModelRegistry::bundled().map_err(|e| {
ProviderError::ExecutionError(format!("Failed to load canonical registry: {}", e))
Expand Down Expand Up @@ -501,9 +498,9 @@ pub trait Provider: Send + Sync {
.collect();

if recommended_models.is_empty() {
Ok(Some(all_models))
Ok(all_models)
} else {
Ok(Some(recommended_models))
Ok(recommended_models)
}
}

Expand Down
4 changes: 4 additions & 0 deletions crates/goose/src/providers/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ impl Provider for BedrockProvider {
self.model.clone()
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
Ok(BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect())
}

#[tracing::instrument(
skip(self, model_config, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,14 +493,10 @@ async fn check_provider(
};

let fetched_models = match provider.fetch_supported_models().await {
Ok(Some(models)) => {
Ok(models) => {
println!(" ✓ Fetched {} models", models.len());
models
}
Ok(None) => {
println!(" ⚠ Provider does not support model listing");
Vec::new()
}
Err(e) => {
println!(" ⚠ Failed to fetch models: {}", e);
println!(" This is expected if credentials are not configured.");
Expand All @@ -509,11 +505,10 @@ async fn check_provider(
};

let recommended_models = match provider.fetch_recommended_models().await {
Ok(Some(models)) => {
Ok(models) => {
println!(" ✓ Found {} recommended models", models.len());
models
}
Ok(None) => Vec::new(),
Err(e) => {
println!(" ⚠ Failed to fetch recommended models: {}", e);
Vec::new()
Expand Down
12 changes: 5 additions & 7 deletions crates/goose/src/providers/chatgpt_codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -985,13 +985,11 @@ impl Provider for ChatGptCodexProvider {
Ok(())
}

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
Ok(Some(
CHATGPT_CODEX_KNOWN_MODELS
.iter()
.map(|s| s.to_string())
.collect(),
))
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
Ok(CHATGPT_CODEX_KNOWN_MODELS
.iter()
.map(|s| s.to_string())
.collect())
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/goose/src/providers/claude_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,13 @@ impl Provider for ClaudeCodeProvider {
self.model.clone()
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
Ok(CLAUDE_CODE_KNOWN_MODELS
.iter()
.map(|s| s.to_string())
.collect())
}

#[tracing::instrument(
skip(self, model_config, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
Expand Down
6 changes: 2 additions & 4 deletions crates/goose/src/providers/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,8 @@ impl Provider for CodexProvider {
))
}

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
Ok(Some(
CODEX_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(),
))
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
Ok(CODEX_KNOWN_MODELS.iter().map(|s| s.to_string()).collect())
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/goose/src/providers/cursor_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,13 @@ impl Provider for CursorAgentProvider {
self.model.clone()
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
Ok(CURSOR_AGENT_KNOWN_MODELS
.iter()
.map(|s| s.to_string())
.collect())
}

#[tracing::instrument(
skip(self, model_config, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
Expand Down
61 changes: 22 additions & 39 deletions crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,51 +391,38 @@ impl Provider for DatabricksProvider {
.map_err(|e| ProviderError::ExecutionError(e.to_string()))
}

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
let response = match self
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let response = self
.api_client
.request(None, "api/2.0/serving-endpoints")
.response_get()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::warn!("Failed to fetch Databricks models: {}", e);
return Ok(None);
}
};
.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to fetch Databricks models: {}", e))
})?;

if !response.status().is_success() {
let status = response.status();
if let Ok(error_text) = response.text().await {
tracing::warn!(
"Failed to fetch Databricks models: {} - {}",
status,
error_text
);
} else {
tracing::warn!("Failed to fetch Databricks models: {}", status);
}
return Ok(None);
let detail = response.text().await.unwrap_or_default();
return Err(ProviderError::RequestFailed(format!(
"Failed to fetch Databricks models: {} {}",
status, detail
)));
}

let json: Value = match response.json().await {
Ok(json) => json,
Err(e) => {
tracing::warn!("Failed to parse Databricks API response: {}", e);
return Ok(None);
}
};
let json: Value = response.json().await.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to parse Databricks API response: {}", e))
})?;

let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
Some(endpoints) => endpoints,
None => {
tracing::warn!(
let endpoints = json
.get("endpoints")
.and_then(|v| v.as_array())
.ok_or_else(|| {
ProviderError::RequestFailed(
"Unexpected response format from Databricks API: missing 'endpoints' array"
);
return Ok(None);
}
};
.to_string(),
)
})?;

let models: Vec<String> = endpoints
.iter()
Expand All @@ -447,11 +434,7 @@ impl Provider for DatabricksProvider {
})
.collect();

if models.is_empty() {
Ok(None)
} else {
Ok(Some(models))
}
Ok(models)
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/goose/src/providers/gcpvertexai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,10 +695,10 @@ impl Provider for GcpVertexAIProvider {
}))
}

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let models: Vec<String> = KNOWN_MODELS.iter().map(|s| s.to_string()).collect();
let filtered = self.filter_by_org_policy(models).await;
Ok(Some(filtered))
Ok(filtered)
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/goose/src/providers/gemini_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,13 @@ impl Provider for GeminiCliProvider {
self.model.clone()
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
Ok(GEMINI_CLI_KNOWN_MODELS
.iter()
.map(|s| s.to_string())
.collect())
}

#[tracing::instrument(
skip(self, _model_config, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
Expand Down
13 changes: 7 additions & 6 deletions crates/goose/src/providers/githubcopilot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ impl Provider for GithubCopilotProvider {
stream_openai_compat(response, log)
}

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let (endpoint, token) = self.get_api_info().await?;
let url = format!("{}/models", endpoint);

Expand All @@ -515,10 +515,11 @@ impl Provider for GithubCopilotProvider {

let json: serde_json::Value = response.json().await?;

let arr = match json.get("data").and_then(|v| v.as_array()) {
Some(arr) => arr,
None => return Ok(None),
};
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::RequestFailed(
"Missing 'data' array in GitHub Copilot models response".to_string(),
)
})?;
let mut models: Vec<String> = arr
.iter()
.filter_map(|m| {
Expand All @@ -532,7 +533,7 @@ impl Provider for GithubCopilotProvider {
})
.collect();
models.sort();
Ok(Some(models))
Ok(models)
}

async fn configure_oauth(&self) -> Result<(), ProviderError> {
Expand Down
Loading
Loading