Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ unicode-normalization = "0.1"
# Vector database for tool selection
lancedb = "0.13"
arrow = "52.2"
oauth2 = "5.0.0"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }
Expand Down
4 changes: 3 additions & 1 deletion crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ impl ExtensionManager {
.await;
let client = if let Err(e) = client_res {
// make an attempt at oauth, but failing that, return the original error,
// because this might not have been an auth error at all
// because this might not have been an auth error at all.
// TODO: when rmcp supports it, we should trigger this flow on 401s with
// WWW-Authenticate headers, not just any init error
let am = match oauth_flow(uri, name).await {
Ok(am) => am,
Err(_) => return Err(e.into()),
Expand Down
26 changes: 23 additions & 3 deletions crates/goose/src/oauth.rs → crates/goose/src/oauth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ use serde::Deserialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{oneshot, Mutex};
use tracing::warn;

use crate::oauth::persist::{clear_credentials, load_cached_state, save_credentials};

mod persist;

const CALLBACK_TEMPLATE: &str = include_str!("oauth_callback.html");

Expand All @@ -28,6 +33,18 @@ pub async fn oauth_flow(
mcp_server_url: &String,
name: &String,
) -> Result<AuthorizationManager, anyhow::Error> {
if let Ok(oauth_state) = load_cached_state(mcp_server_url, name).await {
if let Some(authorization_manager) = oauth_state.into_authorization_manager() {
if authorization_manager.refresh_token().await.is_ok() {
return Ok(authorization_manager);
}
}

if let Err(e) = clear_credentials(name) {
warn!("error clearing bad credentials: {}", e);
}
}

let (code_sender, code_receiver) = oneshot::channel::<String>();
let app_state = AppState {
code_receiver: Arc::new(Mutex::new(Some(code_sender))),
Expand All @@ -52,7 +69,6 @@ pub async fn oauth_flow(
let used_addr = listener.local_addr()?;
tokio::spawn(async move {
let result = axum::serve(listener, app).await;

if let Err(e) = result {
eprintln!("Callback server error: {}", e);
}
Expand All @@ -73,9 +89,13 @@ pub async fn oauth_flow(
let auth_code = code_receiver.await?;
oauth_state.handle_callback(&auth_code).await?;

let am = oauth_state
if let Err(e) = save_credentials(name, &oauth_state).await {
warn!("Failed to save credentials: {}", e);
}

let auth_manager = oauth_state
.into_authorization_manager()
.ok_or_else(|| anyhow::anyhow!("Failed to get authorization manager"))?;

Ok(am)
Ok(auth_manager)
}
72 changes: 72 additions & 0 deletions crates/goose/src/oauth/persist.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse};
use reqwest::IntoUrl;
use rmcp::transport::{auth::OAuthState, AuthError};
use serde::{Deserialize, Serialize};

use crate::config::Config;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableCredentials {
pub client_id: String,
pub token_response: Option<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>>,
}

fn secret_key(name: &str) -> String {
format!("oauth_creds_{name}")
}

pub async fn save_credentials(
name: &str,
oauth_state: &OAuthState,
) -> Result<(), Box<dyn std::error::Error>> {
let config = Config::global();
let (client_id, token_response) = oauth_state.get_credentials().await?;

let credentials = SerializableCredentials {
client_id,
token_response,
};

let value = serde_json::to_value(&credentials)?;
let key = secret_key(name);
config.set_secret(&key, value)?;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how long are the credentials typically valid? maybe I am missing it but shouldn't there be a ttl for the data we store via keyring?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah they'll have an expiry, but that should be handled by the oauth mechanism anyway -- it might need to be refreshed after getting loaded but that should be fine

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested it out with an expired credential and indeed there was a bug! Added a call to refresh on load.


Ok(())
}

async fn load_credentials(
name: &str,
) -> Result<SerializableCredentials, Box<dyn std::error::Error>> {
let config = Config::global();
let key = secret_key(name);
let credentials: SerializableCredentials = config.get_secret(&key)?;

Ok(credentials)
}

pub fn clear_credentials(name: &str) -> Result<(), Box<dyn std::error::Error>> {
let config = Config::global();

Ok(config.delete_secret(&secret_key(name))?)
}

pub async fn load_cached_state<U: IntoUrl>(
base_url: U,
name: &str,
) -> Result<OAuthState, AuthError> {
let credentials = load_credentials(name)
.await
.map_err(|e| AuthError::InternalError(format!("Failed to load credentials: {}", e)))?;

if let Some(token_response) = credentials.token_response {
let mut oauth_state = OAuthState::new(base_url, None).await?;
oauth_state
.set_credentials(&credentials.client_id, token_response)
.await?;
Ok(oauth_state)
} else {
Err(AuthError::InternalError(
"No token response in cached credentials".to_string(),
))
}
}
Loading