From 5f55f7102c57ffaee4c7be8bfb6cd471210944a2 Mon Sep 17 00:00:00 2001 From: Esteban Borai Date: Sat, 20 Jan 2024 18:26:17 -0300 Subject: [PATCH] feat: token aware routes with middleware (#14) Adds support for Bearer tokens in server by implementing a middleware that fetches the account using the token specified in the HTTP headers of the request as: ``` Authorization: Bearer ``` --- Cargo.toml | 2 +- crates/core/src/account/service.rs | 2 +- crates/server/Cargo.toml | 1 + crates/server/src/bin/main.rs | 5 +- crates/server/src/lib.rs | 15 +- crates/server/src/router/api/mod.rs | 51 ++++-- .../server/src/router/api/v1/account/email.rs | 6 +- .../server/src/router/api/v1/account/login.rs | 12 +- .../server/src/router/api/v1/account/mod.rs | 22 +-- .../server/src/router/api/v1/account/root.rs | 5 +- .../src/router/api/v1/account/session.rs | 46 ++++++ .../src/router/api/v1/account/verify_code.rs | 5 +- .../api/v1/account/verify_code_email.rs | 5 +- crates/server/src/router/api/v1/mod.rs | 4 +- crates/server/src/router/middleware/auth.rs | 55 +++++++ crates/server/src/router/middleware/mod.rs | 3 + crates/server/src/router/mod.rs | 8 +- crates/test/src/server/api/v1/account/mod.rs | 1 + .../test/src/server/api/v1/account/session.rs | 148 ++++++++++++++++++ crates/test/src/tools/http.rs | 22 ++- rust-toolchain.toml | 3 + 21 files changed, 354 insertions(+), 67 deletions(-) create mode 100644 crates/server/src/router/api/v1/account/session.rs create mode 100644 crates/server/src/router/middleware/auth.rs create mode 100644 crates/server/src/router/middleware/mod.rs create mode 100644 crates/test/src/server/api/v1/account/session.rs create mode 100644 rust-toolchain.toml diff --git a/Cargo.toml b/Cargo.toml index 82eac10..ba38dff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ resolver = "1" [workspace.dependencies] anyhow = "1.0.75" -axum = "0.6.19" +axum = { version = "0.7.4", features = ["tokio"] } dotenv = "0.15.0" http = "0.2.11" reqwest = "0.11.22" diff --git a/crates/core/src/account/service.rs b/crates/core/src/account/service.rs index 4a8f821..edbc95d 100644 --- a/crates/core/src/account/service.rs +++ b/crates/core/src/account/service.rs @@ -249,7 +249,7 @@ impl AccountService { Ok(credentials.access_token) } - pub async fn whoami(&self, access_token: Secret) -> Result { + pub async fn whoami(&self, access_token: &Secret) -> Result { let session = Session::get(&self.admin, access_token.to_string()) .await .map_err(|err| { diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index 54fb8b5..e185928 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -19,6 +19,7 @@ serde_json = "1.0.108" axum = { workspace = true, features = ["tokio"] } anyhow = { workspace = true } dotenv = { workspace = true } +http = { workspace = true } serde = { workspace = true, features = ["derive"] } tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } tracing = { workspace = true } diff --git a/crates/server/src/bin/main.rs b/crates/server/src/bin/main.rs index a29b855..16a31bb 100644 --- a/crates/server/src/bin/main.rs +++ b/crates/server/src/bin/main.rs @@ -1,7 +1,8 @@ -use std::net::{SocketAddr, TcpListener}; +use std::net::SocketAddr; use anyhow::Result; use dotenv::dotenv; +use tokio::net::TcpListener; #[tokio::main] async fn main() -> Result<()> { @@ -12,7 +13,7 @@ async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let tcp = TcpListener::bind(addr)?; + let tcp = TcpListener::bind(addr).await?; tracing::info!("Listening on {}", addr); diff --git a/crates/server/src/lib.rs b/crates/server/src/lib.rs index 726fcf3..1fc5da7 100644 --- a/crates/server/src/lib.rs +++ b/crates/server/src/lib.rs @@ -1,6 +1,5 @@ -use std::net::TcpListener; - use anyhow::Result; +use tokio::net::TcpListener; pub mod config; pub mod router; @@ -10,15 +9,15 @@ use crate::config::ServerConfig; use crate::router::make_router; use crate::services::Services; -pub async fn serve(tcp: TcpListener) -> Result<()> { +pub async fn serve(listener: TcpListener) -> Result<()> { let config = ServerConfig::from_env(); let services = Services::shared(config).await?; - let router = make_router(); - let router = router.with_state(services); + let router = make_router(services); - axum::Server::from_tcp(tcp)? - .serve(router.into_make_service()) - .await?; + if let Err(err) = axum::serve(listener, router.into_make_service()).await { + tracing::error!(%err, "Failed to initialize the server"); + panic!("An error ocurred running the server!"); + } Ok(()) } diff --git a/crates/server/src/router/api/mod.rs b/crates/server/src/router/api/mod.rs index 8a0ce75..dd2075c 100644 --- a/crates/server/src/router/api/mod.rs +++ b/crates/server/src/router/api/mod.rs @@ -1,46 +1,61 @@ pub mod v1; -use axum::http::StatusCode; use axum::response::IntoResponse; use axum::Json; use axum::Router; +use http::StatusCode; +use serde::Deserialize; use serde::Serialize; use commune::error::HttpStatusCode; -use crate::services::SharedServices; - pub struct Api; impl Api { - pub fn routes() -> Router { - Router::new().nest("/v1", v1::V1::routes()) + pub fn routes() -> Router { + Router::new().nest("/api", Router::new().nest("/v1", v1::V1::routes())) } } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct ApiError { - message: String, - code: &'static str, + pub message: String, + pub code: String, #[serde(skip)] - status: StatusCode, + pub status: StatusCode, } impl ApiError { - pub fn new(message: String, code: &'static str, status: StatusCode) -> Self { + pub fn new(message: String, code: String, status: StatusCode) -> Self { Self { message, code, status, } } + + pub fn unauthorized() -> Self { + Self::new( + "You must be authenticated to access this resource".to_string(), + "UNAUTHORIZED".to_string(), + StatusCode::UNAUTHORIZED, + ) + } + + pub fn internal_server_error() -> Self { + Self::new( + "Internal server error".to_string(), + "INTERNAL_SERVER_ERROR".to_string(), + StatusCode::INTERNAL_SERVER_ERROR, + ) + } } impl From for ApiError { fn from(err: commune::error::Error) -> Self { Self { message: err.to_string(), - code: err.error_code(), + code: err.error_code().to_string(), status: err.status_code(), } } @@ -57,7 +72,7 @@ impl From for ApiError { fn from(err: anyhow::Error) -> Self { Self { message: err.to_string(), - code: "UNKNOWN_ERROR", + code: "UNKNOWN_ERROR".to_string(), status: StatusCode::INTERNAL_SERVER_ERROR, } } @@ -65,10 +80,14 @@ impl From for ApiError { impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { - let status = self.status; - let mut response = Json(self).into_response(); + if let Ok(status) = axum::http::StatusCode::from_u16(self.status.as_u16()) { + let mut response = Json(self).into_response(); + + *response.status_mut() = status; + return response; + } - *response.status_mut() = status; - response + tracing::error!(status=%self.status, "Failed to convert status code to http::StatusCode"); + ApiError::internal_server_error().into_response() } } diff --git a/crates/server/src/router/api/v1/account/email.rs b/crates/server/src/router/api/v1/account/email.rs index 1dc94c8..3fd2975 100644 --- a/crates/server/src/router/api/v1/account/email.rs +++ b/crates/server/src/router/api/v1/account/email.rs @@ -1,7 +1,7 @@ -use axum::extract::{Path, State}; +use axum::extract::Path; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use axum::Json; +use axum::{Extension, Json}; use serde::{Deserialize, Serialize}; use tracing::instrument; @@ -10,7 +10,7 @@ use crate::services::SharedServices; #[instrument(skip(services))] pub async fn handler( - State(services): State, + Extension(services): Extension, Path(email): Path, ) -> Response { match services.commune.account.is_email_available(&email).await { diff --git a/crates/server/src/router/api/v1/account/login.rs b/crates/server/src/router/api/v1/account/login.rs index 1bcd16b..818fa90 100644 --- a/crates/server/src/router/api/v1/account/login.rs +++ b/crates/server/src/router/api/v1/account/login.rs @@ -1,7 +1,6 @@ -use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use axum::Json; +use axum::{Extension, Json}; use commune::Error; use serde::{Deserialize, Serialize}; use tracing::instrument; @@ -15,7 +14,7 @@ use super::root::{AccountMatrixCredentials, AccountSpace}; #[instrument(skip(services, payload))] pub async fn handler( - State(services): State, + Extension(services): Extension, Json(payload): Json, ) -> Response { let login_credentials = LoginCredentials::from(payload); @@ -28,12 +27,7 @@ pub async fn handler( .into_response(); }; - match services - .commune - .account - .whoami(tokens.access_token.clone()) - .await - { + match services.commune.account.whoami(&tokens.access_token).await { Ok(account) => { let mut response = Json(AccountLoginResponse { access_token: tokens.access_token.to_string(), diff --git a/crates/server/src/router/api/v1/account/mod.rs b/crates/server/src/router/api/v1/account/mod.rs index a6714f8..b487959 100644 --- a/crates/server/src/router/api/v1/account/mod.rs +++ b/crates/server/src/router/api/v1/account/mod.rs @@ -1,26 +1,30 @@ pub mod email; pub mod login; pub mod root; +pub mod session; pub mod verify_code; pub mod verify_code_email; use axum::routing::{get, post}; -use axum::Router; +use axum::{middleware, Router}; -use crate::services::SharedServices; +use crate::router::middleware::auth; pub struct Account; impl Account { - pub fn routes() -> Router { - let verify = Router::new() - .route("/code", post(verify_code::handler)) - .route("/code/email", post(verify_code_email::handler)); - + pub fn routes() -> Router { Router::new() + .route("/session", get(session::handler)) + .route_layer(middleware::from_fn(auth)) .route("/", post(root::handler)) - .route("/email/:email", get(email::handler)) .route("/login", post(login::handler)) - .nest("/verify", verify) + .route("/email/:email", get(email::handler)) + .nest( + "/verify", + Router::new() + .route("/code", post(verify_code::handler)) + .route("/code/email", post(verify_code_email::handler)), + ) } } diff --git a/crates/server/src/router/api/v1/account/root.rs b/crates/server/src/router/api/v1/account/root.rs index 1460da7..56af72d 100644 --- a/crates/server/src/router/api/v1/account/root.rs +++ b/crates/server/src/router/api/v1/account/root.rs @@ -1,7 +1,6 @@ -use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use axum::Json; +use axum::{Extension, Json}; use serde::{Deserialize, Serialize}; use tracing::instrument; @@ -16,7 +15,7 @@ use crate::services::SharedServices; #[instrument(skip(services, payload))] pub async fn handler( - State(services): State, + Extension(services): Extension, Json(payload): Json, ) -> Response { let dto = CreateAccountDto::from(payload); diff --git a/crates/server/src/router/api/v1/account/session.rs b/crates/server/src/router/api/v1/account/session.rs new file mode 100644 index 0000000..7b77ba1 --- /dev/null +++ b/crates/server/src/router/api/v1/account/session.rs @@ -0,0 +1,46 @@ +use axum::response::{IntoResponse, Response}; +use axum::{Extension, Json}; +use serde::{Deserialize, Serialize}; +use tracing::instrument; + +use commune::account::model::Account; + +use crate::router::middleware::AccessToken; + +use super::root::{AccountMatrixCredentials, AccountSpace}; + +#[instrument(skip(account))] +pub async fn handler( + Extension(account): Extension, + Extension(access_token): Extension, +) -> Response { + let response = Json(AccountSessionResponse { + credentials: AccountMatrixCredentials { + username: account.username, + display_name: account.display_name, + avatar_url: account.avatar_url, + access_token: access_token.to_string(), + matrix_access_token: access_token.to_string(), + matrix_user_id: account.user_id.to_string(), + matrix_device_id: String::new(), + user_space_id: String::new(), + email: account.email, + age: account.age, + admin: account.admin, + verified: account.verified, + }, + rooms: vec![], + spaces: vec![], + valid: true, + }); + + response.into_response() +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AccountSessionResponse { + pub credentials: AccountMatrixCredentials, + pub rooms: Vec, + pub spaces: Vec, + pub valid: bool, +} diff --git a/crates/server/src/router/api/v1/account/verify_code.rs b/crates/server/src/router/api/v1/account/verify_code.rs index 1d0e0fc..f17547e 100644 --- a/crates/server/src/router/api/v1/account/verify_code.rs +++ b/crates/server/src/router/api/v1/account/verify_code.rs @@ -1,7 +1,6 @@ -use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use axum::Json; +use axum::{Extension, Json}; use commune::account::error::AccountErrorCode; use commune::Error; use serde::{Deserialize, Serialize}; @@ -15,7 +14,7 @@ use crate::services::SharedServices; #[instrument(skip(services, payload))] pub async fn handler( - State(services): State, + Extension(services): Extension, Json(payload): Json, ) -> Response { let dto = SendCodeDto::from(payload); diff --git a/crates/server/src/router/api/v1/account/verify_code_email.rs b/crates/server/src/router/api/v1/account/verify_code_email.rs index dd711c9..898ed2a 100644 --- a/crates/server/src/router/api/v1/account/verify_code_email.rs +++ b/crates/server/src/router/api/v1/account/verify_code_email.rs @@ -1,7 +1,6 @@ -use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use axum::Json; +use axum::{Extension, Json}; use commune::account::error::AccountErrorCode; use commune::util::secret::Secret; use commune::Error; @@ -16,7 +15,7 @@ use crate::services::SharedServices; #[instrument(skip(services, payload))] pub async fn handler( - State(services): State, + Extension(services): Extension, Json(payload): Json, ) -> Response { let dto = VerifyCodeDto::from(payload); diff --git a/crates/server/src/router/api/v1/mod.rs b/crates/server/src/router/api/v1/mod.rs index f57c5bc..360418c 100644 --- a/crates/server/src/router/api/v1/mod.rs +++ b/crates/server/src/router/api/v1/mod.rs @@ -2,12 +2,10 @@ pub mod account; use axum::Router; -use crate::services::SharedServices; - pub struct V1; impl V1 { - pub fn routes() -> Router { + pub fn routes() -> Router { Router::new().nest("/account", account::Account::routes()) } } diff --git a/crates/server/src/router/middleware/auth.rs b/crates/server/src/router/middleware/auth.rs new file mode 100644 index 0000000..977457b --- /dev/null +++ b/crates/server/src/router/middleware/auth.rs @@ -0,0 +1,55 @@ +use axum::body::Body; +use axum::http::{header::AUTHORIZATION, Request}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; + +use commune::util::secret::Secret; + +use crate::router::api::ApiError; +use crate::services::SharedServices; + +#[derive(Debug, Clone)] +pub struct AccessToken(Secret); + +impl ToString for AccessToken { + fn to_string(&self) -> String { + self.0.to_string() + } +} + +pub async fn auth(mut request: Request, next: Next) -> Result { + let access_token = request + .headers() + .get(AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.strip_prefix("Bearer ")) + .ok_or_else(|| { + tracing::warn!("No access token provided"); + ApiError::unauthorized().into_response() + })? + .to_owned(); + + let services = request + .extensions() + .get::() + .ok_or_else(|| { + tracing::error!("SharedServices not found in request extensions"); + ApiError::internal_server_error().into_response() + })?; + + let access_token = Secret::new(access_token); + let user = services + .commune + .account + .whoami(&access_token) + .await + .map_err(|err| { + tracing::error!("Failed to validate token: {}", err); + ApiError::internal_server_error().into_response() + })?; + + request.extensions_mut().insert(user); + request.extensions_mut().insert(AccessToken(access_token)); + + Ok(next.run(request).await) +} diff --git a/crates/server/src/router/middleware/mod.rs b/crates/server/src/router/middleware/mod.rs new file mode 100644 index 0000000..b41b0fe --- /dev/null +++ b/crates/server/src/router/middleware/mod.rs @@ -0,0 +1,3 @@ +mod auth; + +pub use auth::{auth, AccessToken}; diff --git a/crates/server/src/router/mod.rs b/crates/server/src/router/mod.rs index 93f6845..fea3442 100644 --- a/crates/server/src/router/mod.rs +++ b/crates/server/src/router/mod.rs @@ -1,9 +1,13 @@ pub mod api; +pub mod middleware; +use axum::Extension; use axum::Router; use crate::services::SharedServices; -pub fn make_router() -> Router { - Router::new().nest("/api", api::Api::routes()) +pub fn make_router(service: SharedServices) -> Router { + Router::new() + .merge(api::Api::routes()) + .layer(Extension(service)) } diff --git a/crates/test/src/server/api/v1/account/mod.rs b/crates/test/src/server/api/v1/account/mod.rs index 5cdf6d5..c71f243 100644 --- a/crates/test/src/server/api/v1/account/mod.rs +++ b/crates/test/src/server/api/v1/account/mod.rs @@ -1,2 +1,3 @@ mod login; mod root; +mod session; diff --git a/crates/test/src/server/api/v1/account/session.rs b/crates/test/src/server/api/v1/account/session.rs new file mode 100644 index 0000000..d375a6e --- /dev/null +++ b/crates/test/src/server/api/v1/account/session.rs @@ -0,0 +1,148 @@ +use commune_server::router::api::v1::account::root::AccountRegisterPayload; +use commune_server::router::api::v1::account::session::AccountSessionResponse; +use commune_server::router::api::ApiError; +use fake::faker::internet::en::{FreeEmail, Password}; +use fake::Fake; +use reqwest::StatusCode; +use scraper::Selector; +use uuid::Uuid; + +use commune::util::secret::Secret; +use commune_server::router::api::v1::account::login::{AccountLoginPayload, AccountLoginResponse}; +use commune_server::router::api::v1::account::verify_code::{ + AccountVerifyCodePayload, VerifyCodeResponse, +}; +use commune_server::router::api::v1::account::verify_code_email::{ + AccountVerifyCodeEmailPayload, VerifyCodeEmailResponse, +}; + +use crate::tools::http::HttpClient; +use crate::tools::maildev::MailDevClient; + +#[tokio::test] +async fn retrieves_session_user_from_token() { + let http_client = HttpClient::new().await; + let session = Uuid::new_v4(); + let email: String = FreeEmail().fake(); + let verify_code_pld = AccountVerifyCodePayload { + email: email.clone(), + session, + }; + let verify_code_res = http_client + .post("/api/v1/account/verify/code") + .json(&verify_code_pld) + .send() + .await; + let verify_code = verify_code_res.json::().await; + + assert!(verify_code.sent, "should return true for sent"); + + let maildev = MailDevClient::new(); + let mail = maildev.latest().await.unwrap().unwrap(); + let html = mail.html(); + let code_sel = Selector::parse("#code").unwrap(); + let mut code_el = html.select(&code_sel); + let code = code_el.next().unwrap().inner_html(); + let verify_code_email_pld = AccountVerifyCodeEmailPayload { + email: email.clone(), + code: Secret::new(code.clone()), + session, + }; + + let verify_code_res = http_client + .post("/api/v1/account/verify/code/email") + .json(&verify_code_email_pld) + .send() + .await; + let verify_code_email = verify_code_res.json::().await; + + assert!(verify_code_email.valid, "should return true for valid"); + + let username: String = (10..12).fake(); + let username = username.to_ascii_lowercase(); + let password: String = Password(14..20).fake(); + let request_payload = AccountRegisterPayload { + username: username.clone(), + password: password.clone(), + email: email.clone(), + code, + session, + }; + let response = http_client + .post("/api/v1/account") + .json(&request_payload) + .send() + .await; + + assert_eq!( + response.status(), + StatusCode::CREATED, + "should return 201 for successful registration" + ); + + let response = http_client + .post("/api/v1/account/login") + .json(&AccountLoginPayload { + username: username.clone(), + password, + }) + .send() + .await; + let response_status = response.status(); + let response_payload = response.json::().await; + + assert_eq!( + response_status, + StatusCode::OK, + "should return 200 for successful login" + ); + assert!(!response_payload.access_token.is_empty()); + + let session_res = http_client + .get("/api/v1/account/session") + .token(response_payload.access_token) + .send() + .await; + let session_res_status = session_res.status(); + let session_res_payload = session_res.json::().await; + + assert_eq!( + session_res_status, + StatusCode::OK, + "should return 200 for successful session" + ); + assert!(session_res_payload + .credentials + .username + .starts_with(&format!("@{}", username))); + assert_eq!( + session_res_payload.credentials.email, email, + "should return email" + ); + assert!( + session_res_payload.credentials.verified, + "should return verified" + ); + assert!( + !session_res_payload.credentials.admin, + "should return admin" + ); +} + +#[tokio::test] +async fn kicks_users_with_no_token_specified() { + let http_client = HttpClient::new().await; + let session_res = http_client.get("/api/v1/account/session").send().await; + let session_res_status = session_res.status(); + let session_res_payload = session_res.json::().await; + + assert_eq!(session_res_status, StatusCode::UNAUTHORIZED.as_u16(),); + assert_eq!( + session_res_payload.code, "UNAUTHORIZED", + "should return UNAUTHORIZED" + ); + assert_eq!( + session_res_payload.message, + "You must be authenticated to access this resource", + ); +} diff --git a/crates/test/src/tools/http.rs b/crates/test/src/tools/http.rs index 0a8aa76..fcd92e7 100644 --- a/crates/test/src/tools/http.rs +++ b/crates/test/src/tools/http.rs @@ -1,7 +1,8 @@ use std::net::SocketAddr; use dotenv::dotenv; -use reqwest::{Client, StatusCode}; +use reqwest::{header::AUTHORIZATION, Client, StatusCode}; +use tokio::net::TcpListener; use commune_server::serve; @@ -14,9 +15,7 @@ impl HttpClient { pub(crate) async fn new() -> Self { dotenv().ok(); - let tcp = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - tcp.set_nonblocking(true) - .expect("Failed to set non-blocking mode"); + let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = tcp.local_addr().unwrap(); tokio::spawn(async move { @@ -31,6 +30,12 @@ impl HttpClient { HttpClient { client, addr } } + pub(crate) fn get(&self, url: &str) -> RequestBuilder { + RequestBuilder { + builder: self.client.get(self.path(url)), + } + } + pub(crate) fn post(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.post(self.path(url)), @@ -53,6 +58,15 @@ impl RequestBuilder { } } + pub(crate) fn token(mut self, token: impl AsRef) -> Self { + let next = self + .builder + .header(AUTHORIZATION, format!("Bearer {}", token.as_ref())); + + self.builder = next; + self + } + pub(crate) fn json(mut self, json: &T) -> Self where T: serde::Serialize, diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..7cbc6ac --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "1.75.0" +components = [ "rustfmt", "clippy" ]