From 77923fbe0e1507d6436489a4d441b6251c76c1be Mon Sep 17 00:00:00 2001 From: Naiker Date: Sun, 9 Jun 2024 19:01:36 +0100 Subject: [PATCH] fix test and clippy for lib --- examples/demo/src/initializers/oauth2.rs | 4 +- examples/demo/src/models/o_auth2_sessions.rs | 3 +- examples/demo/tests/requests/oauth2.rs | 10 +- src/config.rs | 20 ++- src/controllers/middleware/mod.rs | 4 +- .../{private_cookie_jar.rs => private.rs} | 34 ++--- src/controllers/oauth2.rs | 38 ++++-- src/error.rs | 24 ++-- src/grants/authorization_code.rs | 127 +++++++++--------- src/lib.rs | 24 ++-- .../m20240101_000000_oauth2_sessions.rs | 5 +- src/oauth2_grant.rs | 4 +- 12 files changed, 159 insertions(+), 138 deletions(-) rename src/controllers/middleware/{private_cookie_jar.rs => private.rs} (94%) diff --git a/examples/demo/src/initializers/oauth2.rs b/examples/demo/src/initializers/oauth2.rs index 0497911..ae4606f 100644 --- a/examples/demo/src/initializers/oauth2.rs +++ b/examples/demo/src/initializers/oauth2.rs @@ -1,5 +1,5 @@ use axum::{async_trait, Extension, Router as AxumRouter}; -use loco_oauth2::{config::OAuth2Config, OAuth2ClientStore}; +use loco_oauth2::{config::Config, OAuth2ClientStore}; use loco_rs::prelude::*; pub struct OAuth2StoreInitializer; @@ -19,7 +19,7 @@ impl Initializer for OAuth2StoreInitializer { .get("oauth2") .ok_or(Error::Message("oauth2 config not found".to_string()))? .clone(); - let oauth2_config: OAuth2Config = oauth2_config_value.try_into().map_err(|e| { + let oauth2_config: Config = oauth2_config_value.try_into().map_err(|e| { tracing::error!(error = ?e, "could not convert oauth2 config"); Error::Message("could not convert oauth2 config".to_string()) })?; diff --git a/examples/demo/src/models/o_auth2_sessions.rs b/examples/demo/src/models/o_auth2_sessions.rs index 82bb2ea..285ec7c 100644 --- a/examples/demo/src/models/o_auth2_sessions.rs +++ b/examples/demo/src/models/o_auth2_sessions.rs @@ -1,7 +1,8 @@ use async_trait::async_trait; use chrono::Local; use loco_oauth2::{ - basic::BasicTokenResponse, models::oauth2_sessions::OAuth2SessionsTrait, TokenResponse, + base_oauth2::basic::BasicTokenResponse, base_oauth2::TokenResponse, + models::oauth2_sessions::OAuth2SessionsTrait, }; use loco_rs::model::{ModelError, ModelResult}; use sea_orm::{entity::prelude::*, ActiveValue, TransactionTrait}; diff --git a/examples/demo/tests/requests/oauth2.rs b/examples/demo/tests/requests/oauth2.rs index ee31ddd..acdf952 100644 --- a/examples/demo/tests/requests/oauth2.rs +++ b/examples/demo/tests/requests/oauth2.rs @@ -165,7 +165,7 @@ async fn can_google_authorization_url() -> Result<(), Box serde_urlencoded::to_string([("scope", &settings.scope)])?, ]; - testing::request::(|request, ctx| async move { + testing::request::(|request, _ctx| async move { // Test the authorization url let res = request.get("/api/oauth2/google").await; assert_eq!(res.status_code(), 200); @@ -183,7 +183,7 @@ async fn can_call_google_callback() -> Result<(), Box> { let settings = set_default_url().await; // mock oauth2 server mock_oauth_server(&settings, true).await?; - testing::request::(|request, ctx| async move { + testing::request::(|request, _ctx| async move { // Get the authorization url from the server let auth_res = request.get("/api/oauth2/google").await; // Cookie for csrf token @@ -279,7 +279,7 @@ async fn cannot_call_callback_twice_with_same_csrf_token() -> Result<(), Box(|request, ctx| async move { + testing::request::(|request, _ctx| async move { // Get the authorization url from the server let auth_res = request.get("/api/oauth2/google").await; // Cookie for csrf token @@ -331,7 +331,7 @@ pub async fn cannot_call_google_callback_without_csrf_token( let settings = set_default_url().await; // Mock oauth2 server mock_oauth_server(&settings, false).await?; - testing::request::(|request, ctx| async move { + testing::request::(|request, _ctx| async move { // Test the google callback without csrf token let res = request .get("/api/oauth2/google/callback") @@ -349,7 +349,7 @@ pub async fn cannot_call_google_callback_without_csrf_token( #[tokio::test] #[serial] pub async fn cannot_call_protect_without_cookie() -> Result<(), Box> { - testing::request::(|request, ctx| async move { + testing::request::(|request, _ctx| async move { // hit the protected url without cookies let res = request.get("/api/oauth2/protected").await; assert_eq!(res.status_code(), 401); diff --git a/src/config.rs b/src/config.rs index afe5908..b995b45 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,5 @@ use crate::error::OAuth2StoreError; -use crate::grants::authorization_code::{ - AuthorizationCodeCookieConfig, AuthorizationCodeCredentials, AuthorizationCodeUrlConfig, -}; +use crate::grants::authorization_code::{CookieConfig, Credentials, UrlConfig}; use serde::de::Error; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -34,21 +32,21 @@ use std::str::FromStr; /// timeout_seconds: 600 # Optional, default 600 seconds /// ``` #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct OAuth2Config { +pub struct Config { pub secret_key: Option>, - pub authorization_code: Vec, + pub authorization_code: Vec, } #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct AuthorizationCodeConfig { +pub struct AuthorizationCode { pub client_identifier: String, - pub client_credentials: AuthorizationCodeCredentials, - pub url_config: AuthorizationCodeUrlConfig, - pub cookie_config: AuthorizationCodeCookieConfig, + pub client_credentials: Credentials, + pub url_config: UrlConfig, + pub cookie_config: CookieConfig, pub timeout_seconds: Option, } -impl TryFrom for OAuth2Config { +impl TryFrom for Config { type Error = OAuth2StoreError; #[tracing::instrument(name = "Convert Value to OAuth2Config")] fn try_from(value: Value) -> Result { @@ -59,7 +57,7 @@ impl TryFrom for OAuth2Config { .collect() }); - let authorization_code: Vec = value + let authorization_code: Vec = value .get("authorization_code") .and_then(|v| v.as_array()) .ok_or_else(|| { diff --git a/src/controllers/middleware/mod.rs b/src/controllers/middleware/mod.rs index c5110fd..9b19129 100644 --- a/src/controllers/middleware/mod.rs +++ b/src/controllers/middleware/mod.rs @@ -1,4 +1,4 @@ mod auth; -mod private_cookie_jar; +mod private; pub use auth::*; -pub use private_cookie_jar::*; +pub use private::*; diff --git a/src/controllers/middleware/private_cookie_jar.rs b/src/controllers/middleware/private.rs similarity index 94% rename from src/controllers/middleware/private_cookie_jar.rs rename to src/controllers/middleware/private.rs index bf37390..ace3261 100644 --- a/src/controllers/middleware/private_cookie_jar.rs +++ b/src/controllers/middleware/private.rs @@ -1,5 +1,5 @@ -use crate::grants::authorization_code::AuthorizationCodeCookieConfig; -use crate::{url, OAuth2ClientStore, COOKIE_NAME}; +use crate::grants::authorization_code::CookieConfig; +use crate::{base_oauth2::url, OAuth2ClientStore, COOKIE_NAME}; use async_trait::async_trait; use axum::response::{IntoResponse, IntoResponseParts, ResponseParts}; use axum::{ @@ -44,6 +44,7 @@ impl AsMut for OAuth2PrivateCookieJar { impl OAuth2PrivateCookieJar { #[must_use] #[allow(unused_mut)] + #[allow(clippy::should_implement_trait)] pub fn add>>(mut self, cookie: C) -> Self { Self(self.0.add(cookie.into())) } @@ -84,7 +85,7 @@ pub trait OAuth2PrivateCookieJarTrait: Clone { /// # Errors /// * `Error` - When the cookie cannot be created fn create_short_live_cookie_with_token_response( - config: &AuthorizationCodeCookieConfig, + config: &CookieConfig, token: &BasicTokenResponse, jar: Self, ) -> loco_rs::prelude::Result; @@ -92,7 +93,7 @@ pub trait OAuth2PrivateCookieJarTrait: Clone { impl OAuth2PrivateCookieJarTrait for OAuth2PrivateCookieJar { fn create_short_live_cookie_with_token_response( - config: &AuthorizationCodeCookieConfig, + config: &CookieConfig, token: &BasicTokenResponse, jar: Self, ) -> loco_rs::prelude::Result { @@ -140,8 +141,8 @@ where let Extension(store) = parts .extract::>() .await - .map_err(|err| err.into_response())?; - let key = store.key.clone(); + .map_err(axum::response::IntoResponse::into_response)?; + let key = store.key; let jar = extract::cookie::PrivateCookieJar::from_headers(&parts.headers, key); Ok(Self(jar)) } @@ -149,14 +150,15 @@ where #[cfg(test)] mod tests { use super::*; - use crate::http::StatusCode; + use crate::base_oauth2::http::StatusCode; use axum::routing::get; use axum::Router; use axum_extra::extract::PrivateCookieJar; use axum_test::TestServer; use http::header::{HeaderValue, COOKIE}; - use loco_rs::config::{Config, Database, Middlewares, Server}; + use loco_rs::config::{Config, Database, Logger, Middlewares, Server, Workers}; use loco_rs::environment::Environment; + use sea_orm::DatabaseConnection; use serde_json::json; use std::collections::BTreeMap; @@ -168,11 +170,11 @@ mod tests { fn create_default_app_context() -> AppContext { AppContext { environment: Environment::Production, - db: Default::default(), + db: DatabaseConnection::default(), redis: None, config: Config { initializers: None, - logger: Default::default(), + logger: Logger::default(), server: Server { binding: "test-binding".to_string(), port: 8080, @@ -190,7 +192,7 @@ mod tests { }, }, database: Database { - uri: "".to_string(), + uri: String::new(), enable_logging: false, min_connections: 0, max_connections: 0, @@ -202,7 +204,7 @@ mod tests { }, redis: None, auth: None, - workers: Default::default(), + workers: Workers::default(), mailer: None, settings: None, }, @@ -224,7 +226,7 @@ mod tests { let key = create_key(); let mut headers = HeaderMap::new(); headers.insert(COOKIE, HeaderValue::from_static("")); - let jar = OAuth2PrivateCookieJar::from_headers(&headers, key.clone()); + let jar = OAuth2PrivateCookieJar::from_headers(&headers, key); let cookie_name = "test_cookie"; let cookie_value = "test_value"; @@ -248,7 +250,7 @@ mod tests { let key = create_key(); let mut headers = HeaderMap::new(); headers.insert(COOKIE, HeaderValue::from_static("")); - let jar = OAuth2PrivateCookieJar::from_headers(&headers, key.clone()); + let jar = OAuth2PrivateCookieJar::from_headers(&headers, key); let cookie_name = "test_cookie"; let cookie_value = "test_value"; @@ -293,7 +295,7 @@ mod tests { // Simulate receiving a request with the encrypted cookie let mut headers = HeaderMap::new(); headers.insert("cookie", encrypted_cookie_value.parse().unwrap()); - let private_jar = PrivateCookieJar::from_headers(&HeaderMap::new(), key.clone()); + let private_jar = PrivateCookieJar::from_headers(&HeaderMap::new(), key); let mut original_cookie = None; for cookie in cookies_from_request(&headers) { if let Some(cookie) = private_jar.decrypt(cookie) { @@ -314,7 +316,7 @@ mod tests { let key = create_key(); let mut headers = HeaderMap::new(); headers.insert(COOKIE, HeaderValue::from_static("")); - let jar = OAuth2PrivateCookieJar::from_headers(&headers, key.clone()); + let jar = OAuth2PrivateCookieJar::from_headers(&headers, key); let cookie_name = "test_cookie"; let cookie_value = "test_value"; diff --git a/src/controllers/oauth2.rs b/src/controllers/oauth2.rs index 1b2c29f..64cf1b2 100644 --- a/src/controllers/oauth2.rs +++ b/src/controllers/oauth2.rs @@ -11,7 +11,7 @@ use tokio::sync::MutexGuard; use crate::controllers::middleware::OAuth2PrivateCookieJarTrait; use crate::controllers::middleware::{OAuth2CookieUser, OAuth2PrivateCookieJar}; -use crate::grants::authorization_code::AuthorizationCodeGrantTrait; +use crate::grants::authorization_code::GrantTrait; use crate::models::oauth2_sessions::OAuth2SessionsTrait; use crate::models::users::OAuth2UserTrait; @@ -32,7 +32,7 @@ pub struct AuthParams { /// * `String` - The authorization URL pub async fn get_authorization_url( session: Session, - oauth2_client: &mut MutexGuard<'_, dyn AuthorizationCodeGrantTrait>, + oauth2_client: &mut MutexGuard<'_, dyn GrantTrait>, ) -> String { let (auth_url, csrf_token) = oauth2_client.get_authorization_url(); session.set("CSRF_TOKEN", csrf_token.secret().to_owned()); @@ -43,7 +43,7 @@ pub async fn get_authorization_url + ModelTrait, V: OAuth2SessionsTrait, W: DatabasePool + Clone + Debug + Sync + Send + 'static, @@ -68,7 +68,7 @@ pub async fn callback< params: AuthParams, // Extract the private cookie jar from the request jar: OAuth2PrivateCookieJar, - client: &mut MutexGuard<'_, dyn AuthorizationCodeGrantTrait>, + client: &mut MutexGuard<'_, dyn GrantTrait>, ) -> Result { // Get the CSRF token from the session let csrf_token = session @@ -80,7 +80,10 @@ pub async fn callback< .await .map_err(|e| Error::BadRequest(e.to_string()))?; // Get the user profile - let profile = profile.json::().await.unwrap(); + let profile = profile.json::().await.map_err(|e| { + tracing::error!("Error getting profile: {:?}", e); + Error::InternalServerError + })?; let user = U::upsert_with_oauth(&ctx.db, &profile) .await .map_err(|_e| { @@ -133,6 +136,7 @@ pub async fn google_authorization_url + ModelTrait, V: OAuth2SessionsTrait, W: DatabasePool + Clone + Debug + Sync + Send + 'static, @@ -176,16 +180,30 @@ pub async fn google_callback< Error::InternalServerError })?; let response = callback::(ctx, session, params, jar, &mut client).await?; + drop(client); Ok(response) } +/// The protected URL for the `OAuth2` flow +/// This will return a message indicating that the user is protected +/// +/// # Generics +/// * `T` - The user profile, should implement `DeserializeOwned` and `Send` +/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait` +/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait` +/// # Arguments +/// * `user` - The `OAuth2CookieUser` that holds the user and the session +/// # Returns +/// The response with the message indicating that the user is protected +/// # Errors +/// * `loco_rs::errors::Error` - When the user cannot be retrieved pub async fn protected< - T: DeserializeOwned, + T: DeserializeOwned + Send, U: OAuth2UserTrait + ModelTrait, V: OAuth2SessionsTrait + ModelTrait, >( user: OAuth2CookieUser, ) -> Result { let _user = user.as_ref(); - Ok(format!("You are protected!")) + Ok("You are protected!".to_string()) } diff --git a/src/error.rs b/src/error.rs index f237830..32de0ca 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,12 +10,12 @@ pub enum OAuth2StoreError { ClientNotFound, /// Error for client already exists but different `OAuth2ClientGrantEnum` ClientTypeMismatch(String, OAuth2ClientGrantEnum), - /// Error parsing from configuration JSON to OAuth2 configuration - ConfigJsonError(#[from] serde_json::Error), + /// Error parsing from configuration JSON to `OAuth2` configuration + ConfigJson(#[from] serde_json::Error), /// Error for client creation - ClientCreationError(#[from] OAuth2ClientError), + ClientCreation(#[from] OAuth2ClientError), /// Error for converting key from `[u8]` to `Key` - KeyConversionError(#[from] cookie::KeyError), + KeyConversion(#[from] cookie::KeyError), } impl OAuth2StoreError { @@ -25,20 +25,20 @@ impl OAuth2StoreError { Self::ClientNotFound => "Client not found".to_string(), Self::ClientTypeMismatch(credential_identifier, client) => match client { OAuth2ClientGrantEnum::AuthorizationCode(_) => { - format!("Authorization Code client already exists with credential identifier: {}", credential_identifier) + format!("Authorization Code client already exists with credential identifier: {credential_identifier}", ) } OAuth2ClientGrantEnum::ClientCredentials => { - format!("Client Credentials client already exists with credential identifier: {}", credential_identifier) + format!("Client Credentials client already exists with credential identifier: {credential_identifier}", ) } - OAuth2ClientGrantEnum::DeviceCode => format!("Device Code client already exists with credential identifier: {}", credential_identifier), - OAuth2ClientGrantEnum::Implicit => format!("Implicit client already exists with credential identifier: {}", credential_identifier), + OAuth2ClientGrantEnum::DeviceCode => format!("Device Code client already exists with credential identifier: {credential_identifier}", ), + OAuth2ClientGrantEnum::Implicit => format!("Implicit client already exists with credential identifier: {credential_identifier}", ), OAuth2ClientGrantEnum::ResourceOwnerPasswordCredentials => format!( - "Resource Owner Password Credentials client already exists with credential identifier: {}", credential_identifier + "Resource Owner Password Credentials client already exists with credential identifier: {credential_identifier}", ), }, - Self::ConfigJsonError(err) => format!("Cannot parse JSON for OAuth2 configuration: {}", err), - Self::ClientCreationError(err) => format!("Error creating client: {}", err), - Self::KeyConversionError(err) => format!("Error converting key: {}", err), + Self::ConfigJson(err) => format!("Cannot parse JSON for OAuth2 configuration: {err}", ), + Self::ClientCreation(err) => format!("Error creating client: {err}", ), + Self::KeyConversion(err) => format!("Error converting key: {err}", ), } } } diff --git a/src/grants/authorization_code.rs b/src/grants/authorization_code.rs index 58489eb..5113bb1 100644 --- a/src/grants/authorization_code.rs +++ b/src/grants/authorization_code.rs @@ -15,17 +15,17 @@ use subtle::ConstantTimeEq; use crate::error::{OAuth2ClientError, OAuth2ClientResult}; /// A credentials struct that holds the `OAuth2` client credentials. - For -/// [`AuthorizationCodeClient`] +/// [`Client`] #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct AuthorizationCodeCredentials { +pub struct Credentials { pub client_id: String, pub client_secret: Option, } /// A url config struct that holds the `OAuth2` client related URLs. - For -/// [`AuthorizationCodeClient`] +/// [`Client`] #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct AuthorizationCodeUrlConfig { +pub struct UrlConfig { pub auth_url: String, pub token_url: String, pub redirect_url: String, @@ -34,15 +34,15 @@ pub struct AuthorizationCodeUrlConfig { } /// An url config struct that holds the Cookie related URLs. - For -/// [`AuthorizationCodeClient`] +/// [`Client`] #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct AuthorizationCodeCookieConfig { +pub struct CookieConfig { pub protected_url: Option, } -/// [`AuthorizationCodeClient`] that acts as a client for the Authorization Code +/// [`Client`] that acts as a client for the Authorization Code /// Grant flow. -pub struct AuthorizationCodeClient { +pub struct Client { /// [`BasicClient`] instance for the `OAuth2` client. pub oauth2: BasicClient, /// [`Url`] instance for the `OAuth2` client's profile URL. @@ -57,22 +57,22 @@ pub struct AuthorizationCodeClient { /// A [`std::time::Duration`] for the `OAuth2` client's CSRF token timeout /// which defaults to 10 minutes (600s). pub csrf_token_timeout: std::time::Duration, - /// An optional [`AuthorizationCodeCookieConfig`] for the `OAuth2` client's + /// An optional [`CookieConfig`] for the `OAuth2` client's /// cookie during middleware - pub cookie_config: AuthorizationCodeCookieConfig, + pub cookie_config: CookieConfig, } -impl AuthorizationCodeClient { +impl Client { /// Create a new instance of [`OAuth2Client`]. /// # Arguments - /// * `credentials` - A [`AuthorizationCodeCredentials`] struct that holds + /// * `credentials` - A [`Credentials`] struct that holds /// the `OAuth2` client credentials. - /// * `config` - A [`AuthorizationCodeUrlConfig`] struct that holds the + /// * `config` - A [`UrlConfig`] struct that holds the /// `OAuth2` client related URLs. /// * `timeout_seconds` - An optional timeout in seconds for the csrf token. /// Defaults to 10 minutes (600s). /// # Returns - /// A [`AuthorizationCodeClient`] instance + /// A [`Client`] instance /// # Errors /// [`OAuth2ClientError::UrlError`] if the `auth_url`, `token_url`, /// `redirect_url` or `profile_url` is invalid. @@ -93,9 +93,9 @@ impl AuthorizationCodeClient { /// let client = AuthorizationCodeClient::new(credentials, config, None)?; /// ``` pub fn new( - credentials: AuthorizationCodeCredentials, - config: AuthorizationCodeUrlConfig, - cookie_config: AuthorizationCodeCookieConfig, + credentials: Credentials, + config: UrlConfig, + cookie_config: CookieConfig, timeout_seconds: Option, ) -> OAuth2ClientResult { let client_id = ClientId::new(credentials.client_id); @@ -121,7 +121,7 @@ impl AuthorizationCodeClient { cookie_config, }) } - /// Remove expired flow states within the [`AuthorizationCodeClient`]. + /// Remove expired flow states within the [`Client`]. /// # Example /// ```rust,ignore /// client.remove_expire_flow(); // Clear outdated states within client.flow_states @@ -149,16 +149,16 @@ impl AuthorizationCodeClient { } #[async_trait::async_trait] -pub trait AuthorizationCodeGrantTrait: Send + Sync { +pub trait GrantTrait: Send + Sync { /// Get authorization code client /// # Returns - /// A mutable reference to the [`AuthorizationCodeClient`] instance. - fn get_authorization_code_client(&mut self) -> &mut AuthorizationCodeClient; + /// A mutable reference to the [`Client`] instance. + fn get_authorization_code_client(&mut self) -> &mut Client; /// Get `AuthorizationCodeCookieConfig` instance /// # Returns /// A reference to the `AuthorizationCodeCookieConfig` instance. - fn get_cookie_config(&self) -> &AuthorizationCodeCookieConfig; + fn get_cookie_config(&self) -> &CookieConfig; /// Get authorization URL /// # Returns @@ -239,13 +239,13 @@ pub trait AuthorizationCodeGrantTrait: Send + Sync { /// query. /// * `csrf_token` - A string containing the CSRF token saved in the /// temporary session after the - /// [`AuthorizationCodeClient::get_authorization_url`] method. + /// [`Client::get_authorization_url`] method. /// # Returns /// A tuple containing the token response and the profile response. /// [`BasicTokenResponse`] is the token response from the OAuth2 provider. /// [`Response`] is the profile response from the OAuth2 provider which /// describes the user's profile. This response json information will be - /// determined by [`AuthorizationCodeClient::scopes`] # Errors + /// determined by [`Client::scopes`] # Errors /// An [`OAuth2ClientError::CsrfTokenError`] if the csrf token is invalid. /// An [`OAuth2ClientError::BasicTokenError`] if the token /// exchange fails. @@ -314,7 +314,6 @@ pub trait AuthorizationCodeGrantTrait: Send + Sync { /// Ok((jar, Redirect::to("/protected"))) /// } /// - #[must_use] async fn verify_code_from_callback( &mut self, code: String, @@ -325,7 +324,7 @@ pub trait AuthorizationCodeGrantTrait: Send + Sync { // Clear outdated flow states client.remove_expire_flow(); // Compare csrf token, use subtle to prevent time attack - if !AuthorizationCodeClient::constant_time_compare(&csrf_token, &state) { + if !Client::constant_time_compare(&csrf_token, &state) { return Err(OAuth2ClientError::CsrfTokenError); } // Get the pkce_verifier for exchanging code @@ -353,11 +352,11 @@ pub trait AuthorizationCodeGrantTrait: Send + Sync { } } -impl AuthorizationCodeGrantTrait for AuthorizationCodeClient { - fn get_authorization_code_client(&mut self) -> &mut AuthorizationCodeClient { +impl GrantTrait for Client { + fn get_authorization_code_client(&mut self) -> &mut Client { self } - fn get_cookie_config(&self) -> &AuthorizationCodeCookieConfig { + fn get_cookie_config(&self) -> &CookieConfig { &self.cookie_config } } @@ -420,11 +419,11 @@ mod tests { client_id: "test_client_id".to_string(), client_secret: "test_client_secret".to_string(), code: "test_code".to_string(), - auth_url: format!("{}/auth_url", url), - token_url: format!("{}/token_url", url), - redirect_url: format!("{}/redirect_url", url), - profile_url: format!("{}/profile_url", url), - scope: format!("{}/scope_1", url), + auth_url: format!("{url}/auth_url",), + token_url: format!("{url}/token_url",), + redirect_url: format!("{url}/redirect_url",), + profile_url: format!("{url}/profile_url",), + scope: format!("{url}/scope_1",), exchange_mock_body, profile_mock_body: user_profile, mock_server: server, @@ -437,46 +436,47 @@ mod tests { let host = url.host_str().unwrap_or_default(); // Get the host as a str, default to empty string if not present let path = url.path(); - match url.port() { - Some(port) => format!("{}://{}:{}{}", scheme, host, port, path), - None => format!("{}://{}{}", scheme, host, path), - } + url.port().map_or_else( + || format!("{scheme}://{host}{path}"), + |port| format!("{scheme}://{host}:{port}{path}"), + ) } - async fn create_client() -> OAuth2ClientResult<(AuthorizationCodeClient, Settings)> { + async fn create_client() -> OAuth2ClientResult<(Client, Settings)> { let settings = Settings::new().await; - let credentials = AuthorizationCodeCredentials { + let credentials = Credentials { client_id: settings.client_id.to_string(), client_secret: Some(settings.client_secret.to_string()), }; - let url_config = AuthorizationCodeUrlConfig { + let url_config = UrlConfig { auth_url: settings.auth_url.to_string(), token_url: settings.token_url.to_string(), redirect_url: settings.redirect_url.to_string(), profile_url: settings.profile_url.to_string(), scopes: vec![settings.scope.to_string()], }; - let cookie_config = AuthorizationCodeCookieConfig { + let cookie_config = CookieConfig { protected_url: None, }; - let client = AuthorizationCodeClient::new(credentials, url_config, cookie_config, None)?; + let client = Client::new(credentials, url_config, cookie_config, None)?; Ok((client, settings)) } #[derive(thiserror::Error, Debug)] enum TestError { #[error(transparent)] - OAuth2ClientError(#[from] OAuth2ClientError), + OAuth2Client(#[from] OAuth2ClientError), #[error(transparent)] - ReqwestError(reqwest::Error), + #[allow(dead_code)] + Reqwest(reqwest::Error), #[error("Couldnt find {0}")] - QueryMapError(String), + QueryMap(String), #[error("Unable to deserialize profile")] - ProfileError, + Profile, #[error("Mock json data parse Error")] - MockJsonDataError(#[from] serde_json::Error), + MockJsonData(#[from] serde_json::Error), #[error("Mock form data error")] - MockFormDataError(#[from] serde_urlencoded::ser::Error), + MockFormData(#[from] serde_urlencoded::ser::Error), } #[tokio::test] @@ -485,44 +485,39 @@ mod tests { let (url, csrf_token) = client.get_authorization_url(); let base_url_with_path = get_base_url_with_path(&url); // compare between the auth_url with the base url - assert_eq!(settings.auth_url.to_string(), base_url_with_path); + assert_eq!(settings.auth_url, base_url_with_path); let query_map_multi: HashMap> = form_urlencoded::parse(url.query().unwrap_or("").as_bytes()) .into_owned() .fold(std::collections::HashMap::new(), |mut acc, (key, value)| { - acc.entry(key).or_insert_with(Vec::new).push(value); + acc.entry(key).or_default().push(value); acc }); // Check response type - let response_type = - query_map_multi - .get("response_type") - .ok_or(TestError::QueryMapError( - "Couldnt find response type".to_string(), - ))?; + let response_type = query_map_multi + .get("response_type") + .ok_or(TestError::QueryMap( + "Couldnt find response type".to_string(), + ))?; assert_eq!(response_type[0], "code"); let client_id = query_map_multi .get("client_id") - .ok_or(TestError::QueryMapError( - "Couldnt find client id".to_string(), - ))?; + .ok_or(TestError::QueryMap("Couldnt find client id".to_string()))?; assert_eq!(client_id[0], settings.client_id); // Check redirect url let redirect_url = query_map_multi .get("redirect_uri") - .ok_or(TestError::QueryMapError( - "Couldnt find redirect url".to_string(), - ))?; + .ok_or(TestError::QueryMap("Couldnt find redirect url".to_string()))?; assert_eq!(redirect_url[0], settings.redirect_url); // Check scopes let scopes = query_map_multi .get("scope") - .ok_or(TestError::QueryMapError("Couldnt find scopes".to_string()))?; + .ok_or(TestError::QueryMap("Couldnt find scopes".to_string()))?; assert_eq!(scopes[0], settings.scope); // Check state let state = query_map_multi .get("state") - .ok_or(TestError::QueryMapError("Couldnt find state".to_string()))?; + .ok_or(TestError::QueryMap("Couldnt find state".to_string()))?; assert_eq!(state[0], csrf_token.secret().to_owned()); Ok(()) } @@ -584,7 +579,7 @@ mod tests { let profile = profile .json::() .await - .map_err(|_| TestError::ProfileError)?; + .map_err(|_| TestError::Profile)?; assert_eq!(profile.email, "test_email"); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index ad8ff8f..5efacfb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![allow(elided_lifetimes_in_paths)] use crate::oauth2_grant::OAuth2ClientGrantEnum; use axum::extract::FromRef; use axum_extra::extract::cookie::Key; @@ -13,10 +14,10 @@ pub mod models; pub mod oauth2_grant; const COOKIE_NAME: &str = "sid"; -use crate::config::OAuth2Config; +use crate::config::Config; use crate::error::{OAuth2ClientResult, OAuth2StoreError, OAuth2StoreResult}; -use crate::grants::authorization_code::{AuthorizationCodeClient, AuthorizationCodeGrantTrait}; -pub use oauth2::*; +use crate::grants::authorization_code::{Client, GrantTrait}; +pub use oauth2 as base_oauth2; use tokio::sync::{Mutex, MutexGuard}; #[derive(Clone)] @@ -29,10 +30,13 @@ impl OAuth2ClientStore { /// Create a new instance of `OAuth2ClientStore`. /// # Arguments /// * `config` - An instance of `OAuth2Config` that holds the `OAuth2` configuration. + /// /// # Returns - /// * `OAuth2StoreResult` - A result that holds the `OAuth2ClientStore` if successful, otherwise an `OAuth2StoreError`. - #[must_use] - pub fn new(config: OAuth2Config) -> OAuth2StoreResult { + /// * `Self` - An instance of `OAuth2ClientStore`. + /// + /// # Errors + /// * `OAuth2StoreError` - An error indicating the failure to create the `OAuth2ClientStore`. + pub fn new(config: Config) -> OAuth2StoreResult { let mut clients = BTreeMap::new(); Self::insert_authorization_code_clients(&mut clients, config.authorization_code)?; let key = match config.secret_key { @@ -53,14 +57,14 @@ impl OAuth2ClientStore { )] fn insert_authorization_code_clients( clients: &mut BTreeMap, - authorization_code: Vec, + authorization_code: Vec, ) -> OAuth2ClientResult<()> { for grant in authorization_code { tracing::info!( "Creating Authorization Code Grant client: {:?}", grant.client_identifier ); - let client = AuthorizationCodeClient::new( + let client = Client::new( grant.client_credentials, grant.url_config, grant.cookie_config, @@ -93,10 +97,10 @@ impl OAuth2ClientStore { /// # Returns /// * `OAuth2StoreResult<&AuthorizationCodeClient>` - A result that holds a reference to the `AuthorizationCodeClient` if found, otherwise an `OAuth2StoreError`. #[tracing::instrument(name = "Get Authorization Code Grant client", skip(self))] - pub async fn get_authorization_code_client + std::fmt::Debug>( + pub async fn get_authorization_code_client + std::fmt::Debug + Send>( &self, client_identifier: T, - ) -> OAuth2StoreResult> { + ) -> OAuth2StoreResult> { match self.get(&client_identifier) { Some(OAuth2ClientGrantEnum::AuthorizationCode(client)) => { let client = client.lock().await; diff --git a/src/migration/m20240101_000000_oauth2_sessions.rs b/src/migration/m20240101_000000_oauth2_sessions.rs index 576075c..1a19130 100644 --- a/src/migration/m20240101_000000_oauth2_sessions.rs +++ b/src/migration/m20240101_000000_oauth2_sessions.rs @@ -1,4 +1,7 @@ -use sea_orm_migration::{prelude::*, schema::*}; +use sea_orm_migration::{ + prelude::*, + schema::{integer, pk_auto, string, table_auto, timestamp}, +}; #[derive(DeriveMigrationName)] pub struct Migration; diff --git a/src/oauth2_grant.rs b/src/oauth2_grant.rs index 911de36..09427d9 100644 --- a/src/oauth2_grant.rs +++ b/src/oauth2_grant.rs @@ -2,11 +2,11 @@ use std::sync::Arc; use tokio::sync::Mutex; -use crate::grants::authorization_code::AuthorizationCodeGrantTrait; +use crate::grants::authorization_code::GrantTrait; #[derive(Clone)] pub enum OAuth2ClientGrantEnum { - AuthorizationCode(Arc>), + AuthorizationCode(Arc>), ClientCredentials, DeviceCode, Implicit,