Skip to content
142 changes: 138 additions & 4 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};

use async_trait::async_trait;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
StandardTokenResponse, TokenResponse, TokenUrl,
basic::{BasicClient, BasicTokenType},
Expand Down Expand Up @@ -548,6 +548,23 @@ impl AuthorizationManager {
client_builder = client_builder.set_client_secret(ClientSecret::new(secret));
}

let uses_secret_post = metadata
.additional_fields
.get("token_endpoint_auth_methods_supported")
.and_then(|v| v.as_array())
.map(|arr| {
let has_basic = arr
.iter()
.any(|m| m.as_str() == Some("client_secret_basic"));
let has_post = arr.iter().any(|m| m.as_str() == Some("client_secret_post"));
has_post && !has_basic
})
.unwrap_or(false);

if uses_secret_post {
client_builder = client_builder.set_auth_type(AuthType::RequestBody);
}

self.oauth_client = Some(client_builder);
Ok(())
}
Expand Down Expand Up @@ -1770,14 +1787,14 @@ impl OAuthState {

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use oauth2::{CsrfToken, PkceCodeVerifier};
use oauth2::{AuthType, CsrfToken, PkceCodeVerifier};
use url::Url;

use super::{
AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore,
ScopeUpgradeConfig, StateStore, StoredAuthorizationState, is_https_url,
OAuthClientConfig, ScopeUpgradeConfig, StateStore, StoredAuthorizationState, is_https_url,
};

// -- url helpers --
Expand Down Expand Up @@ -2263,6 +2280,123 @@ mod tests {
manager.set_state_store(TrackingStateStore::default());
}

/// Helper: create an AuthorizationManager with minimal metadata so
/// `configure_client` can be exercised without a live server.
async fn manager_with_metadata(
metadata_override: Option<AuthorizationMetadata>,
) -> AuthorizationManager {
let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap();
mgr.set_metadata(metadata_override.unwrap_or(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
..Default::default()
}));
mgr
}

fn test_client_config() -> OAuthClientConfig {
OAuthClientConfig {
client_id: "my-client".to_string(),
client_secret: Some("my-secret".to_string()),
scopes: vec![],
redirect_uri: "http://localhost/callback".to_string(),
}
}

#[tokio::test]
async fn test_configure_client_uses_client_secret_post_from_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::RequestBody
));
}

#[tokio::test]
async fn test_configure_client_defaults_to_basic_auth() {
let mut mgr = manager_with_metadata(None).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}

#[tokio::test]
async fn test_configure_client_with_explicit_basic_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}

#[tokio::test]
async fn test_configure_client_ignores_unsupported_auth_methods_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["private_key_jwt"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
// Unsupported method should fall through to default (basic auth)
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}

#[tokio::test]
async fn test_configure_client_prefers_basic_when_both_methods_supported() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post", "client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
// -- metadata deserialization --

#[test]
Expand Down