Skip to content

Commit

Permalink
fix test and clippy for lib
Browse files Browse the repository at this point in the history
  • Loading branch information
yinho999 committed Jun 9, 2024
1 parent 325743a commit 77923fb
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 138 deletions.
4 changes: 2 additions & 2 deletions examples/demo/src/initializers/oauth2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::{async_trait, Extension, Router as AxumRouter};
use loco_oauth2::{config::OAuth2Config, OAuth2ClientStore};
use loco_oauth2::{config::Config, OAuth2ClientStore};
use loco_rs::prelude::*;

pub struct OAuth2StoreInitializer;
Expand All @@ -19,7 +19,7 @@ impl Initializer for OAuth2StoreInitializer {
.get("oauth2")
.ok_or(Error::Message("oauth2 config not found".to_string()))?
.clone();
let oauth2_config: OAuth2Config = oauth2_config_value.try_into().map_err(|e| {
let oauth2_config: Config = oauth2_config_value.try_into().map_err(|e| {
tracing::error!(error = ?e, "could not convert oauth2 config");
Error::Message("could not convert oauth2 config".to_string())
})?;
Expand Down
3 changes: 2 additions & 1 deletion examples/demo/src/models/o_auth2_sessions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use async_trait::async_trait;
use chrono::Local;
use loco_oauth2::{
basic::BasicTokenResponse, models::oauth2_sessions::OAuth2SessionsTrait, TokenResponse,
base_oauth2::basic::BasicTokenResponse, base_oauth2::TokenResponse,
models::oauth2_sessions::OAuth2SessionsTrait,
};
use loco_rs::model::{ModelError, ModelResult};
use sea_orm::{entity::prelude::*, ActiveValue, TransactionTrait};
Expand Down
10 changes: 5 additions & 5 deletions examples/demo/tests/requests/oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async fn can_google_authorization_url() -> Result<(), Box<dyn std::error::Error>
serde_urlencoded::to_string([("scope", &settings.scope)])?,
];

testing::request::<App, _, _>(|request, ctx| async move {
testing::request::<App, _, _>(|request, _ctx| async move {
// Test the authorization url
let res = request.get("/api/oauth2/google").await;
assert_eq!(res.status_code(), 200);
Expand All @@ -183,7 +183,7 @@ async fn can_call_google_callback() -> Result<(), Box<dyn std::error::Error>> {
let settings = set_default_url().await;
// mock oauth2 server
mock_oauth_server(&settings, true).await?;
testing::request::<App, _, _>(|request, ctx| async move {
testing::request::<App, _, _>(|request, _ctx| async move {
// Get the authorization url from the server
let auth_res = request.get("/api/oauth2/google").await;
// Cookie for csrf token
Expand Down Expand Up @@ -279,7 +279,7 @@ async fn cannot_call_callback_twice_with_same_csrf_token() -> Result<(), Box<dyn
let settings = set_default_url().await;
// mock oauth2 server
mock_oauth_server(&settings, true).await?;
testing::request::<App, _, _>(|request, ctx| async move {
testing::request::<App, _, _>(|request, _ctx| async move {
// Get the authorization url from the server
let auth_res = request.get("/api/oauth2/google").await;
// Cookie for csrf token
Expand Down Expand Up @@ -331,7 +331,7 @@ pub async fn cannot_call_google_callback_without_csrf_token(
let settings = set_default_url().await;
// Mock oauth2 server
mock_oauth_server(&settings, false).await?;
testing::request::<App, _, _>(|request, ctx| async move {
testing::request::<App, _, _>(|request, _ctx| async move {
// Test the google callback without csrf token
let res = request
.get("/api/oauth2/google/callback")
Expand All @@ -349,7 +349,7 @@ pub async fn cannot_call_google_callback_without_csrf_token(
#[tokio::test]
#[serial]
pub async fn cannot_call_protect_without_cookie() -> Result<(), Box<dyn std::error::Error>> {
testing::request::<App, _, _>(|request, ctx| async move {
testing::request::<App, _, _>(|request, _ctx| async move {
// hit the protected url without cookies
let res = request.get("/api/oauth2/protected").await;
assert_eq!(res.status_code(), 401);
Expand Down
20 changes: 9 additions & 11 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use crate::error::OAuth2StoreError;
use crate::grants::authorization_code::{
AuthorizationCodeCookieConfig, AuthorizationCodeCredentials, AuthorizationCodeUrlConfig,
};
use crate::grants::authorization_code::{CookieConfig, Credentials, UrlConfig};
use serde::de::Error;
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -34,21 +32,21 @@ use std::str::FromStr;
/// timeout_seconds: 600 # Optional, default 600 seconds
/// ```
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OAuth2Config {
pub struct Config {
pub secret_key: Option<Vec<u8>>,
pub authorization_code: Vec<AuthorizationCodeConfig>,
pub authorization_code: Vec<AuthorizationCode>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AuthorizationCodeConfig {
pub struct AuthorizationCode {
pub client_identifier: String,
pub client_credentials: AuthorizationCodeCredentials,
pub url_config: AuthorizationCodeUrlConfig,
pub cookie_config: AuthorizationCodeCookieConfig,
pub client_credentials: Credentials,
pub url_config: UrlConfig,
pub cookie_config: CookieConfig,
pub timeout_seconds: Option<u64>,
}

impl TryFrom<Value> for OAuth2Config {
impl TryFrom<Value> for Config {
type Error = OAuth2StoreError;
#[tracing::instrument(name = "Convert Value to OAuth2Config")]
fn try_from(value: Value) -> Result<Self, Self::Error> {
Expand All @@ -59,7 +57,7 @@ impl TryFrom<Value> for OAuth2Config {
.collect()
});

let authorization_code: Vec<AuthorizationCodeConfig> = value
let authorization_code: Vec<AuthorizationCode> = value
.get("authorization_code")
.and_then(|v| v.as_array())
.ok_or_else(|| {
Expand Down
4 changes: 2 additions & 2 deletions src/controllers/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod auth;
mod private_cookie_jar;
mod private;
pub use auth::*;
pub use private_cookie_jar::*;
pub use private::*;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::grants::authorization_code::AuthorizationCodeCookieConfig;
use crate::{url, OAuth2ClientStore, COOKIE_NAME};
use crate::grants::authorization_code::CookieConfig;
use crate::{base_oauth2::url, OAuth2ClientStore, COOKIE_NAME};
use async_trait::async_trait;
use axum::response::{IntoResponse, IntoResponseParts, ResponseParts};
use axum::{
Expand Down Expand Up @@ -44,6 +44,7 @@ impl AsMut<extract::cookie::PrivateCookieJar> for OAuth2PrivateCookieJar {
impl OAuth2PrivateCookieJar {
#[must_use]
#[allow(unused_mut)]
#[allow(clippy::should_implement_trait)]
pub fn add<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
Self(self.0.add(cookie.into()))
}
Expand Down Expand Up @@ -84,15 +85,15 @@ pub trait OAuth2PrivateCookieJarTrait: Clone {
/// # Errors
/// * `Error` - When the cookie cannot be created
fn create_short_live_cookie_with_token_response(
config: &AuthorizationCodeCookieConfig,
config: &CookieConfig,
token: &BasicTokenResponse,
jar: Self,
) -> loco_rs::prelude::Result<Self>;
}

impl OAuth2PrivateCookieJarTrait for OAuth2PrivateCookieJar {
fn create_short_live_cookie_with_token_response(
config: &AuthorizationCodeCookieConfig,
config: &CookieConfig,
token: &BasicTokenResponse,
jar: Self,
) -> loco_rs::prelude::Result<Self> {
Expand Down Expand Up @@ -140,23 +141,24 @@ where
let Extension(store) = parts
.extract::<Extension<OAuth2ClientStore>>()
.await
.map_err(|err| err.into_response())?;
let key = store.key.clone();
.map_err(axum::response::IntoResponse::into_response)?;
let key = store.key;
let jar = extract::cookie::PrivateCookieJar::from_headers(&parts.headers, key);
Ok(Self(jar))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::http::StatusCode;
use crate::base_oauth2::http::StatusCode;
use axum::routing::get;
use axum::Router;
use axum_extra::extract::PrivateCookieJar;
use axum_test::TestServer;
use http::header::{HeaderValue, COOKIE};
use loco_rs::config::{Config, Database, Middlewares, Server};
use loco_rs::config::{Config, Database, Logger, Middlewares, Server, Workers};
use loco_rs::environment::Environment;
use sea_orm::DatabaseConnection;
use serde_json::json;
use std::collections::BTreeMap;

Expand All @@ -168,11 +170,11 @@ mod tests {
fn create_default_app_context() -> AppContext {
AppContext {
environment: Environment::Production,
db: Default::default(),
db: DatabaseConnection::default(),
redis: None,
config: Config {
initializers: None,
logger: Default::default(),
logger: Logger::default(),
server: Server {
binding: "test-binding".to_string(),
port: 8080,
Expand All @@ -190,7 +192,7 @@ mod tests {
},
},
database: Database {
uri: "".to_string(),
uri: String::new(),
enable_logging: false,
min_connections: 0,
max_connections: 0,
Expand All @@ -202,7 +204,7 @@ mod tests {
},
redis: None,
auth: None,
workers: Default::default(),
workers: Workers::default(),
mailer: None,
settings: None,
},
Expand All @@ -224,7 +226,7 @@ mod tests {
let key = create_key();
let mut headers = HeaderMap::new();
headers.insert(COOKIE, HeaderValue::from_static(""));
let jar = OAuth2PrivateCookieJar::from_headers(&headers, key.clone());
let jar = OAuth2PrivateCookieJar::from_headers(&headers, key);

let cookie_name = "test_cookie";
let cookie_value = "test_value";
Expand All @@ -248,7 +250,7 @@ mod tests {
let key = create_key();
let mut headers = HeaderMap::new();
headers.insert(COOKIE, HeaderValue::from_static(""));
let jar = OAuth2PrivateCookieJar::from_headers(&headers, key.clone());
let jar = OAuth2PrivateCookieJar::from_headers(&headers, key);

let cookie_name = "test_cookie";
let cookie_value = "test_value";
Expand Down Expand Up @@ -293,7 +295,7 @@ mod tests {
// Simulate receiving a request with the encrypted cookie
let mut headers = HeaderMap::new();
headers.insert("cookie", encrypted_cookie_value.parse().unwrap());
let private_jar = PrivateCookieJar::from_headers(&HeaderMap::new(), key.clone());
let private_jar = PrivateCookieJar::from_headers(&HeaderMap::new(), key);
let mut original_cookie = None;
for cookie in cookies_from_request(&headers) {
if let Some(cookie) = private_jar.decrypt(cookie) {
Expand All @@ -314,7 +316,7 @@ mod tests {
let key = create_key();
let mut headers = HeaderMap::new();
headers.insert(COOKIE, HeaderValue::from_static(""));
let jar = OAuth2PrivateCookieJar::from_headers(&headers, key.clone());
let jar = OAuth2PrivateCookieJar::from_headers(&headers, key);

let cookie_name = "test_cookie";
let cookie_value = "test_value";
Expand Down
38 changes: 28 additions & 10 deletions src/controllers/oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tokio::sync::MutexGuard;

use crate::controllers::middleware::OAuth2PrivateCookieJarTrait;
use crate::controllers::middleware::{OAuth2CookieUser, OAuth2PrivateCookieJar};
use crate::grants::authorization_code::AuthorizationCodeGrantTrait;
use crate::grants::authorization_code::GrantTrait;
use crate::models::oauth2_sessions::OAuth2SessionsTrait;
use crate::models::users::OAuth2UserTrait;

Expand All @@ -32,7 +32,7 @@ pub struct AuthParams {
/// * `String` - The authorization URL
pub async fn get_authorization_url<T: DatabasePool + Clone + Debug + Sync + Send + 'static>(
session: Session<T>,
oauth2_client: &mut MutexGuard<'_, dyn AuthorizationCodeGrantTrait>,
oauth2_client: &mut MutexGuard<'_, dyn GrantTrait>,
) -> String {
let (auth_url, csrf_token) = oauth2_client.get_authorization_url();
session.set("CSRF_TOKEN", csrf_token.secret().to_owned());
Expand All @@ -43,7 +43,7 @@ pub async fn get_authorization_url<T: DatabasePool + Clone + Debug + Sync + Send
/// then upsert the user and the session and set the token in a short live
/// cookie Lastly, it will redirect the user to the protected URL
/// # Generics
/// * `T` - The user profile, should implement `DeserializeOwned`
/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
/// * `W` - The database pool
Expand All @@ -58,7 +58,7 @@ pub async fn get_authorization_url<T: DatabasePool + Clone + Debug + Sync + Send
/// # Errors
/// * `loco_rs::errors::Error`
pub async fn callback<
T: DeserializeOwned,
T: DeserializeOwned + Send,
U: OAuth2UserTrait<T> + ModelTrait,
V: OAuth2SessionsTrait<U>,
W: DatabasePool + Clone + Debug + Sync + Send + 'static,
Expand All @@ -68,7 +68,7 @@ pub async fn callback<
params: AuthParams,
// Extract the private cookie jar from the request
jar: OAuth2PrivateCookieJar,
client: &mut MutexGuard<'_, dyn AuthorizationCodeGrantTrait>,
client: &mut MutexGuard<'_, dyn GrantTrait>,
) -> Result<impl IntoResponse> {
// Get the CSRF token from the session
let csrf_token = session
Expand All @@ -80,7 +80,10 @@ pub async fn callback<
.await
.map_err(|e| Error::BadRequest(e.to_string()))?;
// Get the user profile
let profile = profile.json::<T>().await.unwrap();
let profile = profile.json::<T>().await.map_err(|e| {
tracing::error!("Error getting profile: {:?}", e);
Error::InternalServerError
})?;
let user = U::upsert_with_oauth(&ctx.db, &profile)
.await
.map_err(|_e| {
Expand Down Expand Up @@ -133,6 +136,7 @@ pub async fn google_authorization_url<T: DatabasePool + Clone + Debug + Sync + S
Error::InternalServerError
})?;
let auth_url = get_authorization_url(session, &mut client).await;
drop(client);
Ok(auth_url)
}

Expand All @@ -141,7 +145,7 @@ pub async fn google_authorization_url<T: DatabasePool + Clone + Debug + Sync + S
/// then upsert the user and the session and set the token in a short live
/// cookie Lastly, it will redirect the user to the protected URL
/// # Generics
/// * `T` - The user profile, should implement `DeserializeOwned`
/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
/// # Arguments
Expand All @@ -156,7 +160,7 @@ pub async fn google_authorization_url<T: DatabasePool + Clone + Debug + Sync + S
/// # Errors
/// * `loco_rs::errors::Error`
pub async fn google_callback<
T: DeserializeOwned,
T: DeserializeOwned + Send,
U: OAuth2UserTrait<T> + ModelTrait,
V: OAuth2SessionsTrait<U>,
W: DatabasePool + Clone + Debug + Sync + Send + 'static,
Expand All @@ -176,16 +180,30 @@ pub async fn google_callback<
Error::InternalServerError
})?;
let response = callback::<T, U, V, W>(ctx, session, params, jar, &mut client).await?;
drop(client);
Ok(response)
}

/// The protected URL for the `OAuth2` flow
/// This will return a message indicating that the user is protected
///
/// # Generics
/// * `T` - The user profile, should implement `DeserializeOwned` and `Send`
/// * `U` - The user model, should implement `OAuth2UserTrait` and `ModelTrait`
/// * `V` - The session model, should implement `OAuth2SessionsTrait` and `ModelTrait`
/// # Arguments
/// * `user` - The `OAuth2CookieUser` that holds the user and the session
/// # Returns
/// The response with the message indicating that the user is protected
/// # Errors
/// * `loco_rs::errors::Error` - When the user cannot be retrieved
pub async fn protected<
T: DeserializeOwned,
T: DeserializeOwned + Send,
U: OAuth2UserTrait<T> + ModelTrait,
V: OAuth2SessionsTrait<U> + ModelTrait,
>(
user: OAuth2CookieUser<T, U, V>,
) -> Result<impl IntoResponse> {
let _user = user.as_ref();
Ok(format!("You are protected!"))
Ok("You are protected!".to_string())
}
Loading

0 comments on commit 77923fb

Please sign in to comment.