Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use chrono::{DateTime, Utc};
use futures::stream::{FuturesUnordered, StreamExt};
use futures::{future, FutureExt};
use rmcp::service::ClientInitializeError;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::streamable_http_client::{
AuthRequiredError, StreamableHttpClientTransportConfig, StreamableHttpError,
};
use rmcp::transport::{
ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess,
ConfigureCommandExt, DynamicTransportError, SseClientTransport, StreamableHttpClientTransport,
TokioChildProcess,
};
use std::collections::HashMap;
use std::process::Stdio;
Expand Down Expand Up @@ -205,6 +208,28 @@ async fn child_process_client(
}
}

fn extract_auth_error(
res: &Result<McpClient, ClientInitializeError>,
) -> Option<&AuthRequiredError> {
match res {
Ok(_) => None,
Err(err) => match err {
ClientInitializeError::TransportError {
error: DynamicTransportError { error, .. },
..
} => error
.downcast_ref::<StreamableHttpError<reqwest::Error>>()
.and_then(|auth_error| match auth_error {
StreamableHttpError::AuthRequired(auth_required_error) => {
Some(auth_required_error)
}
_ => None,
}),
_ => None,
},
}
}

impl ExtensionManager {
pub fn new() -> Self {
Self {
Expand Down Expand Up @@ -340,15 +365,10 @@ impl ExtensionManager {
),
)
.await;
let client = if let Err(e) = client_res {
// make an attempt at oauth, but failing that, return the original error,
// because this might not have been an auth error at all.
// TODO: when rmcp supports it, we should trigger this flow on 401s with
// WWW-Authenticate headers, not just any init error
let am = match oauth_flow(uri, name).await {
Ok(am) => am,
Err(_) => return Err(e.into()),
};
let client = if let Some(_auth_error) = extract_auth_error(&client_res) {
let am = oauth_flow(uri, name)
.await
.map_err(|_| ExtensionError::SetupError("auth error".to_string()))?;
let client = AuthClient::new(reqwest::Client::default(), am);
let transport = StreamableHttpClientTransport::with_client(
client,
Expand Down
Loading